From 7448f5161303a2d0b2f7988cf59d80ef2110460a Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Fri, 3 Oct 2025 13:51:59 -0700 Subject: [PATCH 01/19] Use connection_factory parameter for SyncEntraConnection --- .../psycopg2/psycopg2_entra_id_extension.py | 181 ++++++++---------- 1 file changed, 80 insertions(+), 101 deletions(-) diff --git a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py index 4eeb1a4..6cfee08 100644 --- a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py +++ b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py @@ -1,110 +1,89 @@ # Copyright (c) Microsoft. All rights reserved. -""" -Connection classes for using Entra auth with Azure DB for PostgreSQL (psycopg2 + aiopg version). -This module provides both synchronous and asynchronous connection classes that allow you to connect to Azure DB for PostgreSQL -using Entra authentication. Uses psycopg2 for sync connections and aiopg for async connections. - -Sync Example (psycopg2): - from azurepg_entra.psycopg2 import connect_with_entra - - conn = connect_with_entra(host="myserver.postgres.database.azure.com", dbname="mydatabase") - -Async Example (aiopg): - from azurepg_entra.psycopg2 import connect_with_entra_async - - conn = await connect_with_entra_async(host="myserver.postgres.database.azure.com", dbname="mydatabase") - -Note: Async functionality requires aiopg: pip install aiopg -""" - -from typing import Any, Optional, TYPE_CHECKING, cast - +import sys +from psycopg2.extensions import connection import psycopg2 +import asyncio +import aiopg from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async -if TYPE_CHECKING: - import aiopg -else: - try: - import aiopg - except ImportError: - aiopg = None - -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential -from azure.identity import DefaultAzureCredential as DefaultAzureCredential -from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential - -AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" -AZURE_MANAGEMENT_SCOPE = "https://management.azure.com/.default" - -def connect_with_entra(credential: Optional[TokenCredential] = None, **kwargs: Any) -> psycopg2.extensions.connection: - """Creates a synchronous PostgreSQL connection using Entra authentication. - - This function handles Azure Entra ID token acquisition and creates a psycopg2 connection - with the appropriate user and password parameters. - - Parameters: - credential (TokenCredential, optional): The credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional connection parameters (host, port, dbname, etc.) - - Returns: - psycopg2.extensions.connection: An open synchronous connection to PostgreSQL. - - Raises: - ValueError: If the provided credential is not a valid TokenCredential. - """ - credential = credential or DefaultAzureCredential() - if credential and not isinstance(credential, TokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous connections") +# Define a custom connection class +class SyncEntraConnection(connection): + def __init__(self, dsn, **kwargs): + # Get Entra credentials before establishing connection + entra_creds = get_entra_conninfo(None) + + # Extract current DSN params and update with Entra credentials + from psycopg2.extensions import parse_dsn, make_dsn + dsn_params = parse_dsn(dsn) if dsn else {} + dsn_params.update(entra_creds) # This should include 'user' and 'password' + + # Create new DSN with Entra credentials + new_dsn = make_dsn(**dsn_params) + + # Call parent constructor with updated DSN + super().__init__(new_dsn, **kwargs) + + def cursor(self, *args, **kwargs): + return super().cursor(*args, **kwargs) - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = get_entra_conninfo(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] - - return cast(psycopg2.extensions.connection, psycopg2.connect(**kwargs)) - -async def connect_with_entra_async(credential: Optional[AsyncTokenCredential] = None, **kwargs: Any) -> aiopg.Connection: - """Creates an asynchronous PostgreSQL connection using Entra authentication. - - This function handles Azure Entra ID token acquisition and creates an aiopg connection - with the appropriate user and password parameters. - - Parameters: - credential (AsyncTokenCredential, optional): The async credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional connection parameters (host, port, dbname, etc.) - - Returns: - aiopg connection: An open asynchronous connection to PostgreSQL. - - Raises: - ImportError: If aiopg is not installed. - ValueError: If the provided credential is not a valid AsyncTokenCredential. - """ - if aiopg is None: - raise ImportError( - "aiopg is required for async connections. Install with: pip install aiopg" - ) +# For async, we need a different approach - use a factory function +async def create_async_entra_connection(**conn_params): + # Get Entra credentials asynchronously + entra_creds = await get_entra_conninfo_async(None) - credential = credential or AsyncDefaultAzureCredential() - if credential and not isinstance(credential, AsyncTokenCredential): - raise ValueError("credential must be an AsyncTokenCredential for async connections") + # Update connection parameters with Entra credentials + conn_params.update(entra_creds) - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = await get_entra_conninfo_async(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] + # Create connection with updated parameters + conn = await aiopg.connect(**conn_params) + return conn - return await aiopg.connect(**kwargs) \ No newline at end of file +# Define a custom connection class +# class AsyncEntraConnection(connection): +# async def __init__(self, dsn, **kwargs): +# # Get Entra credentials before establishing connection +# entra_creds = await get_entra_conninfo_async() + +# # Extract current DSN params and update with Entra credentials +# from psycopg2.extensions import parse_dsn, make_dsn +# dsn_params = parse_dsn(dsn) if dsn else {} +# dsn_params.update(entra_creds) # This should include 'user' and 'password' + +# # Create new DSN with Entra credentials +# new_dsn = make_dsn(**dsn_params) + +# # Call parent constructor with updated DSN +# super().__init__(new_dsn, **kwargs) + +# def cursor(self, *args, **kwargs): +# return super().cursor(*args, **kwargs) + +def sync_test(): + # Use it as a factory + conn = psycopg2.connect( + dbname="postgres", + host="pg-mjm-dev1.postgres.database.azure.com", + connection_factory=SyncEntraConnection + ) + cur = conn.cursor() + cur.execute("SELECT 1") + print(cur.fetchone()) + +async def async_test(): + # Use the factory function instead + conn = await create_async_entra_connection( + dbname="postgres", + host="pg-mjm-dev1.postgres.database.azure.com" + ) + cur = await conn.cursor() + await cur.execute("SELECT 1") + result = await cur.fetchone() + print(result) + conn.close() + +if __name__ == "__main__": + sync_test() + if sys.platform == 'win32': + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.run(async_test()) \ No newline at end of file From 004fb8cd2890a327bd054281657c039d54cd5ac4 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Fri, 3 Oct 2025 13:59:50 -0700 Subject: [PATCH 02/19] Remove user-defining creds --- .../psycopg2/psycopg2_entra_id_extension.py | 34 +------------------ 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py index 6cfee08..f1f7445 100644 --- a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py +++ b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py @@ -1,8 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import sys from psycopg2.extensions import connection -import psycopg2 -import asyncio import aiopg from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async @@ -57,33 +54,4 @@ async def create_async_entra_connection(**conn_params): # super().__init__(new_dsn, **kwargs) # def cursor(self, *args, **kwargs): -# return super().cursor(*args, **kwargs) - -def sync_test(): - # Use it as a factory - conn = psycopg2.connect( - dbname="postgres", - host="pg-mjm-dev1.postgres.database.azure.com", - connection_factory=SyncEntraConnection - ) - cur = conn.cursor() - cur.execute("SELECT 1") - print(cur.fetchone()) - -async def async_test(): - # Use the factory function instead - conn = await create_async_entra_connection( - dbname="postgres", - host="pg-mjm-dev1.postgres.database.azure.com" - ) - cur = await conn.cursor() - await cur.execute("SELECT 1") - result = await cur.fetchone() - print(result) - conn.close() - -if __name__ == "__main__": - sync_test() - if sys.platform == 'win32': - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - asyncio.run(async_test()) \ No newline at end of file +# return super().cursor(*args, **kwargs) \ No newline at end of file From e7f836bf64a3f095fd9d0025ccf9796b1dd19115 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Sun, 5 Oct 2025 14:53:40 -0700 Subject: [PATCH 03/19] Update Python support (source code, sample program, tests) --- python/pyproject.toml | 7 +- .../getting_started/async_pool_utils.py | 118 --- .../create_db_connection_psycopg2.py | 121 +-- .../create_db_connection_psycopg3.py | 25 +- .../create_db_connection_sqlalchemy.py | 155 +--- python/src/azurepg_entra/core.py | 26 +- python/src/azurepg_entra/psycopg2/__init__.py | 43 +- .../psycopg2/psycopg2_entra_id_extension.py | 79 +- python/src/azurepg_entra/psycopg3/__init__.py | 1 - .../psycopg3/psycopg3_entra_id_extension.py | 22 +- .../src/azurepg_entra/sqlalchemy/__init__.py | 31 +- .../sqlalchemy_entra_id_extension.py | 243 ++---- .../test_psycopg2_entra_id_extension.py | 468 +--------- .../test_psycopg3_entra_id_extension.py | 453 +--------- .../test_sqlalchemy_entra_id_extension.py | 807 +++--------------- .../postgresql/test_core_functionality.py | 90 ++ 16 files changed, 541 insertions(+), 2148 deletions(-) delete mode 100644 python/samples/psycopg2/getting_started/async_pool_utils.py create mode 100644 python/tests/azure/data/postgresql/test_core_functionality.py diff --git a/python/pyproject.toml b/python/pyproject.toml index c04baea..70bc2f1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,25 +20,24 @@ classifiers = [ dependencies = [ "azure-identity>=1.13.0", "azure-core>=1.24.0", + "aiohttp>=3.8.0", + "PyJWT>=2.0.0" ] [project.optional-dependencies] # psycopg3 support psycopg3 = [ "psycopg[binary]>=3.1.0", - "psycopg-pool>=3.1.0" ] # psycopg2 support psycopg2 = [ - "psycopg2-binary>=2.9.0", - "aiopg>=1.4.0" + "psycopg2-binary>=2.9.0" ] # SQLAlchemy support sqlalchemy = [ "sqlalchemy>=2.0.0", - "psycopg[binary]>=3.1.0" ] # Development dependencies diff --git a/python/samples/psycopg2/getting_started/async_pool_utils.py b/python/samples/psycopg2/getting_started/async_pool_utils.py deleted file mode 100644 index ac8661e..0000000 --- a/python/samples/psycopg2/getting_started/async_pool_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Async connection pool utilities for psycopg2/aiopg with connection factory support. - -This module provides AsyncEntraConnectionPool, a custom async connection pool -that mimics psycopg2.pool.ThreadedConnectionPool's connection_factory pattern -for asynchronous connections. -""" - -import asyncio -from typing import Callable, Any, Awaitable - - -class AsyncEntraConnectionPool: - """ - Custom async connection pool that supports connection_factory pattern. - - Mimics psycopg2.pool.ThreadedConnectionPool API but for async connections. - This is needed because: - 1. psycopg2 pools are sync-only - 2. aiopg.create_pool doesn't support connection_factory parameter - - Usage: - async def my_connection_factory(): - return await connect_with_entra_async(host="...", dbname="...") - - async with AsyncEntraConnectionPool(my_connection_factory, minconn=1, maxconn=5) as pool: - conn = await pool.getconn() - try: - # Use connection - pass - finally: - pool.putconn(conn) - """ - - def __init__(self, connection_factory: Callable[[], Awaitable[Any]], minconn: int = 1, maxconn: int = 5): - """ - Initialize the async connection pool. - - Args: - connection_factory: Async function that creates and returns a new connection - minconn: Minimum number of connections to maintain in the pool - maxconn: Maximum number of connections allowed - """ - self.connection_factory = connection_factory - self.minconn = minconn - self.maxconn = maxconn - self._pool = [] - self._used = set() - self._lock = asyncio.Lock() - self._closed = False - - async def __aenter__(self): - """Context manager entry - pre-populate with minimum connections.""" - for _ in range(self.minconn): - conn = await self.connection_factory() - self._pool.append(conn) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - close all connections.""" - await self.closeall() - - async def getconn(self): - """ - Get a connection from the pool. - - Returns: - Connection object from the pool or newly created - - Raises: - Exception: If pool is exhausted (all maxconn connections in use) - """ - if self._closed: - raise Exception("Connection pool is closed") - - async with self._lock: - if self._pool: - conn = self._pool.pop() - elif len(self._used) < self.maxconn: - conn = await self.connection_factory() - else: - raise Exception("Connection pool exhausted") - - self._used.add(conn) - return conn - - def putconn(self, conn): - """ - Return a connection to the pool. - - Args: - conn: Connection to return to the pool - """ - if self._closed: - return - - if conn in self._used: - self._used.remove(conn) - if len(self._pool) < self.minconn: - self._pool.append(conn) - else: - # Pool is full, close the connection - conn.close() - - async def closeall(self): - """Close all connections in the pool.""" - self._closed = True - - # Close all pooled connections - for conn in self._pool: - conn.close() - - # Close all connections in use - for conn in list(self._used): - conn.close() - - self._pool.clear() - self._used.clear() \ No newline at end of file diff --git a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py index 319371c..1a94717 100644 --- a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py +++ b/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py @@ -1,23 +1,11 @@ """ -Sample demonstrating both synchronous and asynchronous psycopg2 connections -with Azure Entra ID authentication for Azure PostgreSQL. - -This example shows: -1. Synchronous connection using psycopg2 with Entra ID authentication -2. Asynchronous connection using aiopg with Entra ID authentication - -Both examples use the same Azure Entra ID authentication mechanism to connect -to Azure Database for PostgreSQL. +Sample demonstrating psycopg2 connection with synchronous Entra ID authentication for Azure PostgreSQL. """ from dotenv import load_dotenv -import argparse -import asyncio -import sys import os from psycopg2 import pool -from azurepg_entra.psycopg2 import connect_with_entra, connect_with_entra_async -from async_pool_utils import AsyncEntraConnectionPool +from azurepg_entra.psycopg2 import SyncEntraConnection # Load environment variables from .env file load_dotenv() @@ -25,23 +13,16 @@ DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") def main_sync(): - """Synchronous connection example using psycopg2 with Entra ID authentication and connection pooling.""" - try: - # Create a wrapper function that explicitly passes our server parameters - def entra_connection_factory(*args, **kwargs): - # Ignore any arguments passed by ThreadedConnectionPool and use our explicit parameters - return connect_with_entra( - host=SERVER, - port=5432, - dbname=DATABASE - ) - - # Create a connection pool using psycopg2 with Entra ID authentication + # We pass in the SyncEntraConnection class to enable Entra authentication for the + # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using + # the token itself (with the PostgreSQL scope) as the password. connection_pool = pool.ThreadedConnectionPool( minconn=1, maxconn=5, - connection_factory=entra_connection_factory + host=SERVER, + database=DATABASE, + connection_factory=SyncEntraConnection ) # Get a connection from the pool @@ -49,14 +30,15 @@ def entra_connection_factory(*args, **kwargs): try: with conn.cursor() as cur: + # Query 1 cur.execute("SELECT now()") result = cur.fetchone() - print(f"Sync - Database time: {result[0]}") + print(f"Database time: {result[0]}") - # Test current user query + # Query 2 cur.execute("SELECT current_user") user = cur.fetchone() - print(f"Sync - Connected as: {user[0]}") + print(f"Connected as: {user[0]}") finally: # Return connection to pool connection_pool.putconn(conn) @@ -66,82 +48,5 @@ def entra_connection_factory(*args, **kwargs): print(f"Sync - Error connecting to database: {e}") raise -async def main_async(): - """Asynchronous connection example with custom async pool using connection_factory pattern.""" - - try: - # Create async connection factory function (mirrors the sync version) - async def entra_async_connection_factory(*args, **kwargs): - # Ignore any arguments and use our explicit parameters - return await connect_with_entra_async( - host=SERVER, - port=5432, - dbname=DATABASE - ) - - # Use our custom async pool with connection factory - async with AsyncEntraConnectionPool(entra_async_connection_factory, minconn=1, maxconn=5) as connection_pool: - # Get a connection from the pool (mirrors sync pattern) - conn = await connection_pool.getconn() - - try: - async with conn.cursor() as cur: - await cur.execute("SELECT now()") - result = await cur.fetchone() - print(f"Async - Database time: {result[0]}") - - # Test current user query - await cur.execute("SELECT current_user") - user = await cur.fetchone() - print(f"Async - Connected as: {user[0]}") - finally: - # Return connection to pool (mirrors sync pattern) - connection_pool.putconn(conn) - - except Exception as e: - print(f"Async - Error connecting to database: {e}") - raise - -async def main(mode: str = "both"): - """Main function that runs sync and/or async examples based on mode. - - Args: - mode: "sync", "async", or "both" to determine which examples to run - """ - if mode in ("sync", "both"): - print("=== Running Synchronous Example ===") - try: - main_sync() - print("✅ Sync example completed successfully!") - except Exception as e: - print(f"❌ Sync example failed: {e}") - - if mode in ("async", "both"): - if mode == "both": - print("\n=== Running Asynchronous Example ===") - else: - print("=== Running Asynchronous Example ===") - try: - await main_async() - print("✅ Async example completed successfully!") - except Exception as e: - print(f"❌ Async example failed: {e}") - if __name__ == "__main__": - # Parse command line arguments - parser = argparse.ArgumentParser( - description="Demonstrate psycopg2/aiopg connections with Azure Entra ID authentication" - ) - parser.add_argument( - "--mode", - choices=["sync", "async", "both"], - default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" - ) - args = parser.parse_args() - - # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main(args.mode)) \ No newline at end of file + main_sync() \ No newline at end of file diff --git a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py index 9d95eb0..6445189 100644 --- a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py +++ b/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py @@ -1,13 +1,6 @@ """ Sample demonstrating both synchronous and asynchronous psycopg connections with Azure Entra ID authentication for Azure PostgreSQL. - -This example shows: -1. Synchronous connection using SyncEntraConnection and ConnectionPool -2. Asynchronous connection using AsyncEntraConnection and AsyncConnectionPool - -Both examples use the same Azure Entra ID authentication mechanism to connect -to Azure Database for PostgreSQL. """ from psycopg_pool import AsyncConnectionPool, ConnectionPool @@ -27,6 +20,9 @@ def main_sync(): """Synchronous connection example using psycopg with Entra ID authentication.""" try: + # We pass in the SyncEntraConnection class to enable Entra authentication for the + # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using + # the token itself (with the PostgreSQL scope) as the password. pool = ConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, @@ -36,11 +32,12 @@ def main_sync(): ) pool.open() with pool, pool.connection() as conn, conn.cursor() as cur: + # Query 1 cur.execute("SELECT now()") result = cur.fetchone() print(f"Sync - Database time: {result}") - # Test current user query + # Query 2 cur.execute("SELECT current_user") user = cur.fetchone() print(f"Sync - Connected as: {user[0]}") @@ -52,20 +49,24 @@ async def main_async(): """Asynchronous connection example using psycopg with Entra ID authentication.""" try: + # We pass in the AsyncEntraConnection class to enable Entra authentication for the + # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using + # the token itself (with the PostgreSQL scope) as the password. pool = AsyncConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, max_size=5, open=False, - connection_class=AsyncEntraConnection + connection_class=AsyncEntraConnection ) await pool.open() async with pool, pool.connection() as conn, conn.cursor() as cur: + # Query 1 await cur.execute("SELECT now()") result = await cur.fetchone() print(f"Async - Database time: {result}") - - # Test current user query + + # Query 2 await cur.execute("SELECT current_user") user = await cur.fetchone() print(f"Async - Connected as: {user[0]}") @@ -73,7 +74,7 @@ async def main_async(): print(f"Async - Error connecting to database: {e}") raise -async def main(mode: str = "both"): +async def main(mode: str = "async"): """Main function that runs sync and/or async examples based on mode. Args: diff --git a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py index 918e2d7..56ce94e 100644 --- a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py +++ b/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py @@ -1,25 +1,19 @@ """ Sample demonstrating both synchronous and asynchronous SQLAlchemy connections with Azure Entra ID authentication for Azure PostgreSQL. - -This example shows: -1. Synchronous connection using create_engine_with_entra and connection pooling -2. Asynchronous connection using create_async_engine_with_entra and async connection pooling - -Both examples use the same Azure Entra ID authentication mechanism to connect -to Azure Database for PostgreSQL. """ +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine +from azurepg_entra.sqlalchemy import enable_entra_authentication, enable_entra_authentication_async from dotenv import load_dotenv import argparse import asyncio import sys import os -from sqlalchemy import text -from azurepg_entra.sqlalchemy import create_engine_with_entra, create_async_engine_with_entra # Load environment variables from .env file -load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), '.env')) +load_dotenv() SERVER = os.getenv("POSTGRES_SERVER") DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") @@ -27,34 +21,26 @@ def main_sync(): """Synchronous connection example using SQLAlchemy with Entra ID authentication.""" try: - # Create engine with Entra ID authentication and connection pooling - engine = create_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=5, - max_overflow=10, - pool_pre_ping=True, # Validate connections before use - echo=False # Set to True to see SQL queries - ) + # Create a synchronous engine + engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") - # Test connection + # We add an event listener to the engine to enable Entra authentication for the + # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using + # the token itself (with the PostgreSQL scope) as the password. This event listener is triggered + # whenever we get a NEW connection from the pool backing the engine. + enable_entra_authentication(engine) + with engine.connect() as conn: + # Query 1 result = conn.execute(text("SELECT now()")) - db_time = result.fetchone() - print(f"Sync - Database time: {db_time[0]}") + print(f"Sync - Database time: {result.fetchone()[0]}") - # Test current user query + # Query 2 result = conn.execute(text("SELECT current_user")) - user = result.fetchone() - print(f"Sync - Connected as: {user[0]}") + print(f"Sync - Connected as: {result.fetchone()[0]}") - # Test a simple query to verify functionality - result = conn.execute(text("SELECT 'SQLAlchemy Sync Entra Connection Working!' as message")) - message = result.fetchone() - print(f"Sync - Test message: {message[0]}") - - # Clean up the sync engine + # Clean up the engine engine.dispose() - except Exception as e: print(f"Sync - Error connecting to database: {e}") raise @@ -63,123 +49,52 @@ async def main_async(): """Asynchronous connection example using SQLAlchemy with Entra ID authentication.""" try: - # Create async engine with Entra ID authentication and connection pooling - engine = create_async_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=5, - max_overflow=10, - pool_pre_ping=True, # Validate connections before use - echo=False # Set to True to see SQL queries - ) + # Create an asynchronous engine + engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable Entra authentication for the + # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using + # the token itself (with the PostgreSQL scope) as the password. This event listener is triggered + # whenever we get a NEW connection from the pool backing the engine. + enable_entra_authentication_async(engine) - # Test async connection async with engine.connect() as conn: + # Query 1 result = await conn.execute(text("SELECT now()")) - db_time = result.fetchone() - print(f"Async - Database time: {db_time[0]}") - - # Test current user query + print(f"Async - Database time: {result.fetchone()[0]}") + + # Query 2 result = await conn.execute(text("SELECT current_user")) - user = result.fetchone() - print(f"Async - Connected as: {user[0]}") - - # Test a simple query to verify functionality - result = await conn.execute(text("SELECT 'SQLAlchemy Async Entra Connection Working!' as message")) - message = result.fetchone() - print(f"Async - Test message: {message[0]}") + print(f"Async - Connected as: {result.fetchone()[0]}") - # Clean up the async engine + # Clean up the engine await engine.dispose() - except Exception as e: print(f"Async - Error connecting to database: {e}") raise -def test_connection_pool_refresh(): - """Test that connection pool handles token refresh properly (sync version).""" - - try: - print("\n=== Testing Connection Pool Token Refresh (Sync) ===") - - engine = create_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=2, - max_overflow=0, - echo=False - ) - - # Make multiple connections to test token refresh - for i in range(3): - with engine.connect() as conn: - result = conn.execute(text("SELECT current_user, now()")) - user, db_time = result.fetchone() - print(f"Connection {i+1} - User: {user}, Time: {db_time}") - - # Clean up the sync engine - engine.dispose() - print("✅ Connection pool token refresh test completed successfully!") - - except Exception as e: - print(f"❌ Connection pool test failed: {e}") - raise - -async def test_async_connection_pool_refresh(): - """Test that async connection pool handles token refresh properly.""" - - try: - print("\n=== Testing Async Connection Pool Token Refresh ===") - - engine = create_async_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=2, - max_overflow=0, - echo=False - ) - - # Make multiple async connections to test token refresh - for i in range(3): - async with engine.connect() as conn: - result = await conn.execute(text("SELECT current_user, now()")) - user, db_time = result.fetchone() - print(f"Async Connection {i+1} - User: {user}, Time: {db_time}") - - await engine.dispose() - print("✅ Async connection pool token refresh test completed successfully!") - - except Exception as e: - print(f"❌ Async connection pool test failed: {e}") - raise - -async def main(mode: str = "both"): +async def main(mode: str = "async"): """Main function that runs sync and/or async examples based on mode. Args: mode: "sync", "async", or "both" to determine which examples to run """ if mode in ("sync", "both"): - print("=== Running Synchronous Example ===") + print("=== Running Synchronous SQLAlchemy Example ===") try: main_sync() print("✅ Sync example completed successfully!") - - # Test connection pool behavior - test_connection_pool_refresh() - except Exception as e: print(f"❌ Sync example failed: {e}") if mode in ("async", "both"): if mode == "both": - print("\n=== Running Asynchronous Example ===") + print("\n=== Running Asynchronous SQLAlchemy Example ===") else: - print("=== Running Asynchronous Example ===") + print("=== Running Asynchronous SQLAlchemy Example ===") try: await main_async() print("✅ Async example completed successfully!") - - # Test async connection pool behavior - await test_async_connection_pool_refresh() - except Exception as e: print(f"❌ Async example failed: {e}") diff --git a/python/src/azurepg_entra/core.py b/python/src/azurepg_entra/core.py index 76d11e0..626ce77 100644 --- a/python/src/azurepg_entra/core.py +++ b/python/src/azurepg_entra/core.py @@ -1,6 +1,5 @@ import logging -import json -import base64 +import jwt from typing import Any, cast from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential @@ -46,19 +45,22 @@ async def get_entra_token_async(credential: AsyncTokenCredential | None, scope: cred = await credential.get_token(scope) return cred.token -def decode_jwt(token: str) -> dict[str, Any]: +def decode_jwt(token: str) -> dict[str, Any] | None: """Decodes a JWT token to extract its payload claims. Parameters: token (str): The JWT token string in the standard three-part format. Returns: - dict: A dictionary containing the claims extracted from the token payload. + dict | None: A dictionary containing the claims extracted from the token payload, + or None if the token is invalid. """ - payload = token.split(".")[1] - padding = "=" * (4 - len(payload) % 4) - decoded_payload = base64.urlsafe_b64decode(payload + padding) - return cast(dict[str, Any], json.loads(decoded_payload)) + try: + # Decode without verification since we only need the payload claims + # Azure tokens are already validated by the credential provider + return cast(dict[str, Any], jwt.decode(token, options={"verify_signature": False})) + except Exception: + return None def parse_principal_name(xms_mirid: str) -> str | None: """Parses the principal name from an Azure resource path. @@ -104,6 +106,8 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: # Always get the DB-scope token for password db_token = get_entra_token(credential, AZURE_DB_FOR_POSTGRES_SCOPE) db_claims = decode_jwt(db_token) + if not db_claims: + raise ValueError("Invalid DB token format") xms_mirid = db_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None @@ -116,6 +120,8 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: # Fall back to management scope ONLY to discover username mgmt_token = get_entra_token(credential, AZURE_MANAGEMENT_SCOPE) mgmt_claims = decode_jwt(mgmt_token) + if not mgmt_claims: + raise ValueError("Invalid management token format") xms_mirid = mgmt_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None @@ -149,6 +155,8 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d db_token = await get_entra_token_async(credential, AZURE_DB_FOR_POSTGRES_SCOPE) db_claims = decode_jwt(db_token) + if not db_claims: + raise ValueError("Invalid DB token format") xms_mirid = db_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None @@ -160,6 +168,8 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d if not username: mgmt_token = await get_entra_token_async(credential, AZURE_MANAGEMENT_SCOPE) mgmt_claims = decode_jwt(mgmt_token) + if not mgmt_claims: + raise ValueError("Invalid management token format") xms_mirid = mgmt_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None diff --git a/python/src/azurepg_entra/psycopg2/__init__.py b/python/src/azurepg_entra/psycopg2/__init__.py index ddcdf20..e30a0b9 100644 --- a/python/src/azurepg_entra/psycopg2/__init__.py +++ b/python/src/azurepg_entra/psycopg2/__init__.py @@ -1,56 +1,43 @@ # Copyright (c) Microsoft. All rights reserved. """ -Psycopg2 + aiopg support for Azure Entra ID authentication with Azure Database for PostgreSQL. +Psycopg2 support for Azure Entra ID authentication with Azure Database for PostgreSQL. -This module provides connection functions that handle Azure Entra ID token acquisition -and authentication for both synchronous (psycopg2) and asynchronous (aiopg) PostgreSQL connections. +This module provides connection classes that handle Azure Entra ID token acquisition +and authentication for synchronous (psycopg2) PostgreSQL connections. Requirements: Install with: pip install azurepg-entra[psycopg2] This will install: - psycopg2-binary>=2.8.0 - - aiopg>=1.3.0 (for async support) -Functions: - connect_with_entra: Synchronous connection function with Entra ID authentication (psycopg2) - connect_with_entra_async: Asynchronous connection function with Entra ID authentication (aiopg) - get_entra_conninfo: Synchronous function to get Entra authentication info - get_entra_conninfo_async: Asynchronous function to get Entra authentication info +Classes: + SyncEntraConnection: Synchronous connection class with Entra ID authentication (psycopg2) Example usage: # Synchronous connection - from azurepg_entra.psycopg2 import connect_with_entra + from azurepg_entra.psycopg2 import SyncEntraConnection - conn = connect_with_entra( - host="myserver.postgres.database.azure.com", - dbname="mydatabase", - port=5432 - ) - - # Asynchronous connection - from azurepg_entra.psycopg2 import connect_with_entra_async - - conn = await connect_with_entra_async( - host="myserver.postgres.database.azure.com", - dbname="mydatabase", - port=5432 + connection_pool = pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=SERVER, + database=DATABASE, + connection_factory=SyncEntraConnection ) """ try: from .psycopg2_entra_id_extension import ( - connect_with_entra, - connect_with_entra_async + SyncEntraConnection, ) __all__ = [ - "connect_with_entra", - "connect_with_entra_async" + "SyncEntraConnection", ] except ImportError as e: - # Provide a helpful error message if psycopg2/aiopg dependencies are missing + # Provide a helpful error message if psycopg2 dependencies are missing raise ImportError( "psycopg2 dependencies are not installed. " "Install them with: pip install azurepg-entra[psycopg2]" diff --git a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py index f1f7445..dab59df 100644 --- a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py +++ b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py @@ -1,57 +1,44 @@ # Copyright (c) Microsoft. All rights reserved. -from psycopg2.extensions import connection -import aiopg - -from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async +from psycopg2.extensions import connection, parse_dsn, make_dsn +from azurepg_entra.core import get_entra_conninfo # Define a custom connection class class SyncEntraConnection(connection): + """Establishes a synchronous PostgreSQL connection using Entra authentication. + + The method checks for provided credentials. If the 'user' or 'password' are not set + in the DSN or keyword arguments, it acquires them from Entra via the provided or default credential. + + Parameters: + dsn: PostgreSQL connection string. + **kwargs: Keyword arguments including optional 'credential', and optionally 'user' and 'password'. + + Raises: + ValueError: If the provided credential is not a valid TokenCredential. + """ def __init__(self, dsn, **kwargs): - # Get Entra credentials before establishing connection - entra_creds = get_entra_conninfo(None) - - # Extract current DSN params and update with Entra credentials - from psycopg2.extensions import parse_dsn, make_dsn + # Extract current DSN params dsn_params = parse_dsn(dsn) if dsn else {} - dsn_params.update(entra_creds) # This should include 'user' and 'password' - # Create new DSN with Entra credentials - new_dsn = make_dsn(**dsn_params) + # Check if user and password are already provided + has_user = 'user' in dsn_params or 'user' in kwargs + has_password = 'password' in dsn_params or 'password' in kwargs - # Call parent constructor with updated DSN - super().__init__(new_dsn, **kwargs) - - def cursor(self, *args, **kwargs): - return super().cursor(*args, **kwargs) - -# For async, we need a different approach - use a factory function -async def create_async_entra_connection(**conn_params): - # Get Entra credentials asynchronously - entra_creds = await get_entra_conninfo_async(None) - - # Update connection parameters with Entra credentials - conn_params.update(entra_creds) - - # Create connection with updated parameters - conn = await aiopg.connect(**conn_params) - return conn - -# Define a custom connection class -# class AsyncEntraConnection(connection): -# async def __init__(self, dsn, **kwargs): -# # Get Entra credentials before establishing connection -# entra_creds = await get_entra_conninfo_async() + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + entra_creds = get_entra_conninfo(None) + + # Only update missing credentials + if not has_user and 'user' in entra_creds: + dsn_params['user'] = entra_creds['user'] + if not has_password and 'password' in entra_creds: + dsn_params['password'] = entra_creds['password'] -# # Extract current DSN params and update with Entra credentials -# from psycopg2.extensions import parse_dsn, make_dsn -# dsn_params = parse_dsn(dsn) if dsn else {} -# dsn_params.update(entra_creds) # This should include 'user' and 'password' + # Update DSN params with any kwargs (kwargs take precedence) + dsn_params.update(kwargs) -# # Create new DSN with Entra credentials -# new_dsn = make_dsn(**dsn_params) + # Create new DSN with updated credentials + new_dsn = make_dsn(**dsn_params) -# # Call parent constructor with updated DSN -# super().__init__(new_dsn, **kwargs) - -# def cursor(self, *args, **kwargs): -# return super().cursor(*args, **kwargs) \ No newline at end of file + # Call parent constructor with updated DSN only + super().__init__(new_dsn) \ No newline at end of file diff --git a/python/src/azurepg_entra/psycopg3/__init__.py b/python/src/azurepg_entra/psycopg3/__init__.py index 4eace8d..81a7a39 100644 --- a/python/src/azurepg_entra/psycopg3/__init__.py +++ b/python/src/azurepg_entra/psycopg3/__init__.py @@ -10,7 +10,6 @@ This will install: - psycopg[binary]>=3.1.0 - - psycopg-pool>=3.1.0 Classes: SyncEntraConnection: Synchronous connection class with Entra ID authentication diff --git a/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py b/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py index e32c4c6..2a7ade4 100644 --- a/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py +++ b/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py @@ -1,19 +1,4 @@ # Copyright (c) Microsoft. All rights reserved. -""" -Connection classes for using Entra auth with Azure DB for PostgreSQL. -This module provides both synchronous and asynchronous connection classes that allow you to connect to Azure DB for PostgreSQL -using Entra authentication with psycopg. It handles token acquisition and connection setup. It is not specific -to this repository and can be used in any project that requires Entra authentication with Azure DB for PostgreSQL. - -For example: - - from azure_pg_entra import AsyncEntraConnection - from psycopg_pool import AsyncConnectionPool - - async with AsyncConnectionPool("", connection_class=AsyncEntraConnection) as pool: - ... - -""" from typing import Any try: @@ -47,11 +32,8 @@ def connect(cls, *args: Any, **kwargs: Any) -> Self: ValueError: If the provided credential is not a valid TokenCredential. """ credential = kwargs.pop("credential", None) - if credential: - if isinstance(credential, AsyncTokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous connections") - if not isinstance(credential, TokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous connections") + if credential and not isinstance(credential, (TokenCredential)): + raise ValueError("credential must be a TokenCredential for sync connections") # Check if we need to acquire Entra authentication info if not kwargs.get("user") or not kwargs.get("password"): diff --git a/python/src/azurepg_entra/sqlalchemy/__init__.py b/python/src/azurepg_entra/sqlalchemy/__init__.py index 7e91176..c7922b5 100644 --- a/python/src/azurepg_entra/sqlalchemy/__init__.py +++ b/python/src/azurepg_entra/sqlalchemy/__init__.py @@ -1,14 +1,37 @@ # Copyright (c) Microsoft. All rights reserved. """ SQLAlchemy integration for Azure PostgreSQL with Entra ID authentication. + +This module provides seamless integration between SQLAlchemy and Azure Entra ID +authentication for PostgreSQL connections. It automatically handles token acquisition +and credential injection through SQLAlchemy's event system. + +Usage: + Synchronous engines: + from sqlalchemy import create_engine + from azurepg_entra.sqlalchemy import enable_entra_authentication + + engine = create_engine("postgresql://myserver.postgres.database.azure.com/mydb") + enable_entra_authentication(engine) + + Asynchronous engines: + from sqlalchemy.ext.asyncio import create_async_engine + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + + engine = create_async_engine("postgresql+asyncpg://myserver.postgres.database.azure.com/mydb") + enable_entra_authentication_async(engine) + +Functions: + enable_entra_authentication: Enable Entra ID auth for synchronous SQLAlchemy engines + enable_entra_authentication_async: Enable Entra ID auth for asynchronous SQLAlchemy engines """ from .sqlalchemy_entra_id_extension import ( - create_engine_with_entra, - create_async_engine_with_entra, + enable_entra_authentication, + enable_entra_authentication_async, ) __all__ = [ - "create_engine_with_entra", - "create_async_engine_with_entra", + "enable_entra_authentication", + "enable_entra_authentication_async", ] \ No newline at end of file diff --git a/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py b/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py index 030ce0f..81a3a21 100644 --- a/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py +++ b/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py @@ -1,169 +1,102 @@ # Copyright (c) Microsoft. All rights reserved. -import psycopg +import asyncio import logging -from typing import Optional, Any, TYPE_CHECKING -from urllib.parse import urlparse, urlunparse -from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async +import sys +from sqlalchemy import Engine, event +from sqlalchemy.ext.asyncio import AsyncEngine +from typing import Optional from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential -from azure.identity import DefaultAzureCredential as DefaultAzureCredential -from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential -from sqlalchemy.engine.interfaces import DBAPIConnection -from sqlalchemy.ext.asyncio.engine import AsyncEngine - -try: - from sqlalchemy import create_engine, Engine -except ImportError: - raise ImportError("sqlalchemy is required. Install with: pip install sqlalchemy") - -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import create_async_engine -else: - try: - from sqlalchemy.ext.asyncio import create_async_engine - except ImportError: - create_async_engine = None +from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def create_engine_with_entra( - url: str, - credential: Optional[TokenCredential] = None, - **kwargs: Any -) -> Engine: - """Creates a SQLAlchemy Engine using Entra authentication for Azure PostgreSQL. - - This function handles Azure Entra ID token acquisition and creates a SQLAlchemy engine - that automatically refreshes tokens for each new connection. This solves the token - expiration issue by acquiring fresh tokens on each connection attempt. - - Parameters: - url (str): The database URL. Username and password will be replaced with Entra credentials. - credential (TokenCredential, optional): The credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional engine creation parameters passed to create_engine() - - Returns: - Engine: A SQLAlchemy engine configured with Entra authentication and automatic token refresh. - - Raises: - ValueError: If the provided credential is not a valid TokenCredential. - - Example: - engine = create_engine_with_entra( - "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - ) +def enable_entra_authentication(engine: Engine, credential: Optional[TokenCredential] = None): """ - credential = credential or DefaultAzureCredential() - if credential and not isinstance(credential, TokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous engines") + Enable Azure Entra ID authentication for a SQLAlchemy engine. - # Parse the original URL to extract connection parameters - parsed = urlparse(url) + This function registers an event listener that automatically provides + Entra ID credentials for each database connection if they are not already set. - def connect_with_fresh_token() -> DBAPIConnection | None: - """Custom connection factory that gets a fresh token each time.""" - logger.info("Creating new connection with fresh Entra token") - - # Get fresh Entra authentication info for each connection - entra_conninfo = get_entra_conninfo(credential) - - # Build authenticated URL with fresh token - parsed_copy = parsed._replace( - netloc=f"{entra_conninfo['user']}:{entra_conninfo['password']}@{parsed.hostname}" + - (f":{parsed.port}" if parsed.port else "") - ) - auth_url = urlunparse(parsed_copy) - - # Create a temporary engine with the authenticated URL and get the DBAPI connection - temp_engine = create_engine(auth_url) - raw_conn = temp_engine.raw_connection() - # Return the underlying DBAPI connection, not the SQLAlchemy wrapper - return raw_conn.dbapi_connection - - # Create base URL without credentials for the engine - base_url = f"{parsed.scheme or 'postgresql'}://{parsed.hostname}" - if parsed.port: - base_url += f":{parsed.port}" - if parsed.path: - base_url += parsed.path - if parsed.query: - base_url += f"?{parsed.query}" + Args: + engine: The SQLAlchemy Engine to enable Entra authentication for + credential: Optional Azure credential. If None, uses DefaultAzureCredential + """ - # Create engine with custom connection factory - return create_engine(base_url, creator=connect_with_fresh_token, **kwargs) + @event.listens_for(engine, "do_connect") + def provide_token(dialect, conn_rec, cargs, cparams): + """Event handler that provides Entra credentials for each connection.""" + try: + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + entra_creds = get_entra_conninfo(credential) + + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] + + logger.debug(f"Provided Entra credentials for user: {entra_creds.get('user', 'unknown')}") + else: + logger.debug("User and password already present, skipping Entra authentication") + except Exception as e: + logger.error(f"Failed to get Entra credentials: {e}") + raise -def create_async_engine_with_entra( - url: str, - credential: Optional[AsyncTokenCredential] = None, - **kwargs: Any -) -> AsyncEngine: - """Creates an async SQLAlchemy Engine using Entra authentication for Azure PostgreSQL. - This function handles Azure Entra ID token acquisition and creates an async SQLAlchemy engine - that automatically refreshes tokens for each new connection. This solves the token - expiration issue by acquiring fresh tokens on each connection attempt. - - Parameters: - url (str): The database URL. Username and password will be replaced with Entra credentials. - credential (AsyncTokenCredential, optional): The async credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional engine creation parameters passed to create_async_engine() - - Returns: - AsyncEngine: An async SQLAlchemy engine configured with Entra authentication and automatic token refresh. - - Raises: - ImportError: If sqlalchemy.ext.asyncio is not available. - ValueError: If the provided credential is not a valid AsyncTokenCredential. - - Example: - engine = await create_async_engine_with_entra( - "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - ) +def enable_entra_authentication_async(engine: AsyncEngine, credential: Optional[TokenCredential] = None): """ - if create_async_engine is None: - raise ImportError( - "sqlalchemy.ext.asyncio is required for async engines. " - "Install with: pip install sqlalchemy[asyncio]" - ) - - credential = credential or AsyncDefaultAzureCredential() - if credential and not isinstance(credential, AsyncTokenCredential): - raise ValueError("credential must be an AsyncTokenCredential for async engines") - - # Parse the original URL to extract connection parameters - parsed = urlparse(url) - - async def async_connect_with_fresh_token() -> psycopg.AsyncConnection: - """Custom async connection factory that gets a fresh token each time.""" - logger.info("Creating new async connection with fresh Entra token") - - # Get fresh Entra authentication info for each connection - entra_conninfo = await get_entra_conninfo_async(credential) - - # For async, we need to return the raw async connection directly - # Import the appropriate async driver and create connection directly - if parsed.scheme == 'postgresql+psycopg' or parsed.scheme == 'postgresql': - return await psycopg.AsyncConnection.connect( - host=parsed.hostname, - port=parsed.port or 5432, - dbname=parsed.path.lstrip('/') if parsed.path else 'postgres', - user=entra_conninfo['user'], - password=entra_conninfo['password'] - ) - else: - raise ValueError(f"Unsupported async URL scheme: {parsed.scheme}. Use postgresql+psycopg or postgresql") - - # Create base URL without credentials for the engine - base_url = f"{parsed.scheme}://{parsed.hostname}" - if parsed.port: - base_url += f":{parsed.port}" - if parsed.path: - base_url += parsed.path - if parsed.query: - base_url += f"?{parsed.query}" + Enable Azure Entra ID authentication for an async SQLAlchemy engine. + + This function registers an event listener that automatically provides + Entra ID credentials for each database connection if they are not already set. - # Create async engine with custom connection factory - return create_async_engine(base_url, async_creator=async_connect_with_fresh_token, **kwargs) + Args: + engine: The async SQLAlchemy Engine to enable Entra authentication for + credential: Optional Azure credential. If None, uses DefaultAzureCredential + """ + + @event.listens_for(engine.sync_engine, "do_connect") + def provide_token_async(dialect, conn_rec, cargs, cparams): + """Event handler that provides Entra credentials for each async connection.""" + try: + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + # For async engines, we need to handle the async credential fetching + try: + # Try to get the current event loop + asyncio.get_running_loop() + # If we're in a running loop, we need to run the async function in a thread + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, get_entra_conninfo_async(credential)) + entra_creds = future.result() + except RuntimeError: + # No running event loop, we can use asyncio.run directly + # Set Windows event loop policy for compatibility if needed + if sys.platform.startswith('win'): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + entra_creds = asyncio.run(get_entra_conninfo_async(credential)) + + logger.debug("Successfully obtained async Entra credentials") + + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] + + logger.debug(f"Provided async Entra credentials for user: {entra_creds.get('user', 'unknown')}") + else: + logger.debug("User and password already present, skipping Entra authentication") + except Exception as e: + logger.error(f"Failed to get async Entra credentials: {e}") + raise \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py index bbc6dbf..c52a351 100644 --- a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py @@ -1,460 +1,46 @@ # Copyright (c) Microsoft. All rights reserved. -""" -Unit Tests for Azure PostgreSQL psycopg Entra ID Extension - -This test suite demonstrates and validates the Azure Entra ID authentication -functionality for PostgreSQL connections using psycopg2. These tests serve as -both validation and examples of how to use the extension. - -Test Categories: -1. JWT Token Decoding - Validates Azure token processing -2. Principal Name Parsing - Tests managed identity resource path parsing -3. Connection Info Generation - Tests core authentication logic -4. Sync/Async Connection Classes - Validates connection establishment - -Key Testing Patterns: -- Every synchronous test has an equivalent asynchronous test -- Comprehensive mocking to avoid external dependencies -- Edge case validation for robust error handling -- Clear test naming that describes expected behavior - -Usage: - # Run all tests - pytest test_psycopg_entra_id_extension.py - - # Run specific test class - pytest test_psycopg2_entra_id_extension.py::TestConnectWithEntra - - # Run with verbose output - pytest -v test_psycopg_entra_id_extension.py - -Dependencies: - pip install pytest pytest-asyncio - -For more information about Azure Entra ID authentication: -https://docs.microsoft.com/en-us/azure/postgresql/concepts-aad-authentication -""" - -import base64 -import json +import jwt import pytest -from unittest.mock import AsyncMock, Mock, patch -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential - -from azurepg_entra.psycopg2 import ( - connect_with_entra, - connect_with_entra_async, - decode_jwt, - get_entra_conninfo, - get_entra_conninfo_async, - parse_principal_name -) - -# Test Configuration -# These tests use mocking to avoid requiring actual Azure credentials or database connections. -# For integration testing with real Azure resources, see the samples/ directory. - +from unittest.mock import Mock, patch +from psycopg2.extensions import parse_dsn, make_dsn def create_test_token(payload): """Helper to create a test JWT token.""" - encoded_payload = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') - return f"header.{encoded_payload}.signature" - - -class TestDecodeJwt: - """ - Tests for JWT token decoding functionality. - - These tests validate that Azure Entra ID tokens are properly decoded - to extract user information. The extension supports various token formats - and claim structures that Azure may provide. - """ - - def test_decode_jwt_valid_token(self): - """Test decoding a valid JWT token with UPN (User Principal Name) claim.""" - # UPN is the most common claim for user identity in Azure AD tokens - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_with_padding(self): - """Test decoding JWT token that requires base64 padding.""" - payload = {"preferred_username": "testuser", "exp": 9999999999} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_minimal_payload(self): - """Test decoding JWT with minimal payload.""" - payload = {"unique_name": "user123"} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - + return jwt.encode(payload, key="", algorithm="none") -class TestParsePrincipalName: - """ - Tests for Azure resource path parsing. - - When using managed identities, Azure provides resource paths that need - to be parsed to extract the identity name for database authentication. - These tests ensure robust parsing of various path formats. - """ - def test_parse_principal_name_valid_user_assigned(self): - """Test parsing a valid user-assigned identity resource path.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - def test_parse_principal_name_empty_string(self): - """Test parsing an empty string returns None.""" - assert parse_principal_name("") is None - - def test_parse_principal_name_none(self): - """Test parsing None returns None.""" - assert parse_principal_name(None) is None - - def test_parse_principal_name_no_slash(self): - """Test parsing a string without slashes returns None.""" - assert parse_principal_name("no-slashes-here") is None - - def test_parse_principal_name_invalid_path(self): - """Test parsing an invalid resource path returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/SomeOther/resource/my-identity") - assert result is None - - def test_parse_principal_name_missing_identity_name(self): - """Test parsing a path without identity name returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/") - assert result is None - - def test_parse_principal_name_case_insensitive(self): - """Test parsing with different case variations.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/MICROSOFT.MANAGEDIDENTITY/USERASSIGNEDIDENTITIES/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - -class TestGetEntraConninfo: - """ - Tests for the core authentication logic (sync and async versions). - - These functions handle the complete flow of: - 1. Requesting Azure tokens with appropriate scopes - 2. Decoding tokens to extract user information - 3. Handling fallback scenarios for managed identities - 4. Returning connection parameters for psycopg - - Both synchronous and asynchronous patterns are tested to ensure - consistent behavior across different usage scenarios. - """ - - # Sync tests - def test_get_entra_conninfo_with_credential(self): - """Test getting connection info with sync credential and upn claim.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} +class TestSyncEntraConnection: + def test_dsn_processing_adds_entra_credentials(self): + """Test that SyncEntraConnection logic correctly merges Entra credentials into DSN.""" + payload = {"upn": "user@example.com"} token = create_test_token(payload) - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - def test_get_entra_conninfo_no_username_claims(self): - """Test error when no username claims are present.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - get_entra_conninfo(mock_credential) - - def test_get_entra_conninfo_username_priority(self): - """Test that upn takes priority over other username claims.""" - mock_credential = Mock(spec=TokenCredential) - # Azure tokens may contain multiple username claims - test priority order - payload = { - "upn": "upn@example.com", # Highest priority - "preferred_username": "preferred@example.com", # Second priority - "unique_name": "unique@example.com" # Fallback option - } - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result["user"] == "upn@example.com" # Should use highest priority claim - - def test_get_entra_conninfo_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username.""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] + with patch('azurepg_entra.core.get_entra_conninfo') as mock_get_creds: + mock_get_creds.return_value = {"user": "user@example.com", "password": token} - result = get_entra_conninfo(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - # Async tests - mirror of sync tests - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_with_credential(self): - """Test getting connection info with async credential and upn claim.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_no_username_claims(self): - """Test error when no username claims are present (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - await get_entra_conninfo_async(mock_credential) - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_username_priority(self): - """Test that upn takes priority over other username claims (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = { - "preferred_username": "preferred@example.com", - "unique_name": "unique@example.com", - "upn": "upn@example.com" - } - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result["user"] == "upn@example.com" - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username (async).""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] + from azurepg_entra.core import get_entra_conninfo - result = await get_entra_conninfo_async(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - -class TestConnectWithEntra: - """ - Tests for the connect_with_entra function. - - This function handles Azure Entra ID authentication for psycopg2 connections. - Tests validate: - - Proper credential handling and validation - - Fallback to standard authentication when credentials exist - - Integration with the underlying psycopg2.connect logic - """ - - def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } - - with patch('psycopg2.connect') as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection + # Test with existing DSN parameters + original_dsn = "host=localhost port=5432 dbname=testdb sslmode=require" + entra_creds = get_entra_conninfo(None) - result = connect_with_entra(**kwargs) + dsn_params = parse_dsn(original_dsn) if original_dsn else {} + dsn_params.update(entra_creds) + new_dsn = make_dsn(**dsn_params) - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "existing_user" - assert call_args["password"] == "existing_password" - assert result == mock_connection - - def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" - mock_credential = Mock(spec=TokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb" - } - - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_conninfo', return_value=expected_conninfo) as mock_get_conninfo: - with patch('psycopg2.connect') as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection - - result = connect_with_entra(credential=mock_credential, **kwargs) - - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "test@example.com" - assert call_args["password"] == "token123" - assert result == mock_connection - - def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost"} - - with pytest.raises(ValueError, match="credential must be a TokenCredential for synchronous connections"): - connect_with_entra(credential=invalid_credential, **kwargs) - - -class TestConnectWithEntraAsync: - """ - Tests for the connect_with_entra_async function. - - This function handles Azure Entra ID authentication for aiopg connections. - Tests mirror the synchronous tests to ensure consistent behavior between - sync/async usage patterns. - """ - - @pytest.mark.asyncio - async def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } - - with patch('aiopg.connect', new_callable=AsyncMock) as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection + mock_get_creds.assert_called_once_with(None) - result = await connect_with_entra_async(**kwargs) - - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "existing_user" - assert call_args["password"] == "existing_password" - assert result == mock_connection - - @pytest.mark.asyncio - async def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb" - } - - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_conninfo_async', return_value=expected_conninfo) as mock_get_conninfo: - with patch('aiopg.connect', new_callable=AsyncMock) as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection - - result = await connect_with_entra_async(credential=mock_credential, **kwargs) - - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "test@example.com" - assert call_args["password"] == "token123" - assert result == mock_connection - - @pytest.mark.asyncio - async def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost"} - - with pytest.raises(ValueError, match="credential must be an AsyncTokenCredential for async connections"): - await connect_with_entra_async(credential=invalid_credential, **kwargs) + # Original params preserved + assert "host=localhost" in new_dsn + assert "port=5432" in new_dsn + assert "dbname=testdb" in new_dsn + assert "sslmode=require" in new_dsn + # Entra creds added + assert "user=user@example.com" in new_dsn + assert f"password={token}" in new_dsn -# Example usage and test runner if __name__ == "__main__": - """ - Direct execution example for development and validation. - - This runs all tests with verbose output, which is helpful for: - - Understanding test coverage - - Debugging test failures - - Learning expected behavior patterns - - For CI/CD or automated testing, use pytest directly: - pytest test_psycopg_entra_id_extension.py -v --tb=short - """ import sys - - # Run with verbose output and short traceback format exit_code = pytest.main([__file__, "-v", "--tb=short"]) - - # Provide helpful guidance based on results - if exit_code == 0: - print("\n✅ All tests passed! The Azure Entra ID extension is working correctly.") - print("💡 Next steps: Try running the samples in samples/psycopg/getting_started/") - else: - print(f"\n❌ Some tests failed (exit code: {exit_code})") - print("💡 Check the test output above for details on any failures.") - print("💡 Ensure all dependencies are installed: pip install pytest pytest-asyncio") - - sys.exit(exit_code) - - -""" -Quick Start Guide for Understanding These Tests: - -1. Basic JWT Testing (TestDecodeJwt): - - Shows how Azure tokens are decoded to get user info - - Demonstrates different token claim formats - -2. Resource Path Parsing (TestParsePrincipalName): - - Tests managed identity resource path handling - - Shows how identity names are extracted for database auth - -3. Core Authentication Logic (TestGetEntraConninfo): - - Tests the main authentication flow - - Shows both sync and async patterns - - Demonstrates fallback mechanisms for edge cases - -4. Connection Functions (TestConnectWithEntra/TestConnectWithEntraAsync): - - Tests the customer-facing connection classes - - Shows how existing psycopg code can be easily adapted - - Validates credential handling and error cases - -Key Patterns to Notice: -- Every sync test has an async equivalent for consistency -- Mocking isolates tests from external dependencies -- Clear error messages help with debugging -- Tests serve as usage examples for developers - -For integration testing with real Azure resources, see the samples directory. -""" \ No newline at end of file + sys.exit(exit_code) \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py index 837141c..47f7777 100644 --- a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py @@ -1,42 +1,4 @@ # Copyright (c) Microsoft. All rights reserved. -""" -Unit Tests for Azure PostgreSQL psycopg Entra ID Extension - -This test suite demonstrates and validates the Azure Entra ID authentication -functionality for PostgreSQL connections using psycopg3. These tests serve as -both validation and examples of how to use the extension. - -Test Categories: -1. JWT Token Decoding - Validates Azure token processing -2. Principal Name Parsing - Tests managed identity resource path parsing -3. Connection Info Generation - Tests core authentication logic -4. Sync/Async Connection Classes - Validates connection establishment - -Key Testing Patterns: -- Every synchronous test has an equivalent asynchronous test -- Comprehensive mocking to avoid external dependencies -- Edge case validation for robust error handling -- Clear test naming that describes expected behavior - -Usage: - # Run all tests - pytest test_psycopg_entra_id_extension.py - - # Run specific test class - pytest test_psycopg_entra_id_extension.py::TestSyncEntraConnection - - # Run with verbose output - pytest -v test_psycopg_entra_id_extension.py - -Dependencies: - pip install pytest pytest-asyncio - -For more information about Azure Entra ID authentication: -https://docs.microsoft.com/en-us/azure/postgresql/concepts-aad-authentication -""" - -import base64 -import json import pytest from unittest.mock import AsyncMock, Mock, patch from azure.core.credentials import TokenCredential @@ -45,423 +7,92 @@ from azurepg_entra.psycopg3 import ( AsyncEntraConnection, SyncEntraConnection, - get_entra_conninfo, - get_entra_conninfo_async, - decode_jwt, - parse_principal_name, ) -# Test Configuration -# These tests use mocking to avoid requiring actual Azure credentials or database connections. -# For integration testing with real Azure resources, see the samples/ directory. - - -def create_test_token(payload): - """Helper to create a test JWT token.""" - encoded_payload = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') - return f"header.{encoded_payload}.signature" - - -class TestDecodeJwt: - """ - Tests for JWT token decoding functionality. - - These tests validate that Azure Entra ID tokens are properly decoded - to extract user information. The extension supports various token formats - and claim structures that Azure may provide. - """ - - def test_decode_jwt_valid_token(self): - """Test decoding a valid JWT token with UPN (User Principal Name) claim.""" - # UPN is the most common claim for user identity in Azure AD tokens - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_with_padding(self): - """Test decoding JWT token that requires base64 padding.""" - payload = {"preferred_username": "testuser", "exp": 9999999999} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_minimal_payload(self): - """Test decoding JWT with minimal payload.""" - payload = {"unique_name": "user123"} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - -class TestParsePrincipalName: - """ - Tests for Azure resource path parsing. - - When using managed identities, Azure provides resource paths that need - to be parsed to extract the identity name for database authentication. - These tests ensure robust parsing of various path formats. - """ - - def test_parse_principal_name_valid_user_assigned(self): - """Test parsing a valid user-assigned identity resource path.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - def test_parse_principal_name_empty_string(self): - """Test parsing an empty string returns None.""" - assert parse_principal_name("") is None - - def test_parse_principal_name_none(self): - """Test parsing None returns None.""" - assert parse_principal_name(None) is None - - def test_parse_principal_name_no_slash(self): - """Test parsing a string without slashes returns None.""" - assert parse_principal_name("no-slashes-here") is None - - def test_parse_principal_name_invalid_path(self): - """Test parsing an invalid resource path returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/SomeOther/resource/my-identity") - assert result is None - - def test_parse_principal_name_missing_identity_name(self): - """Test parsing a path without identity name returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/") - assert result is None - - def test_parse_principal_name_case_insensitive(self): - """Test parsing with different case variations.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/MICROSOFT.MANAGEDIDENTITY/USERASSIGNEDIDENTITIES/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - -class TestGetEntraConninfo: - """ - Tests for the core authentication logic (sync and async versions). - - These functions handle the complete flow of: - 1. Requesting Azure tokens with appropriate scopes - 2. Decoding tokens to extract user information - 3. Handling fallback scenarios for managed identities - 4. Returning connection parameters for psycopg - - Both synchronous and asynchronous patterns are tested to ensure - consistent behavior across different usage scenarios. - """ - - # Sync tests - def test_get_entra_conninfo_with_credential(self): - """Test getting connection info with sync credential and upn claim.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - def test_get_entra_conninfo_no_username_claims(self): - """Test error when no username claims are present.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - get_entra_conninfo(mock_credential) - - def test_get_entra_conninfo_username_priority(self): - """Test that upn takes priority over other username claims.""" - mock_credential = Mock(spec=TokenCredential) - # Azure tokens may contain multiple username claims - test priority order - payload = { - "upn": "upn@example.com", # Highest priority - "preferred_username": "preferred@example.com", # Second priority - "unique_name": "unique@example.com" # Fallback option - } - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result["user"] == "upn@example.com" # Should use highest priority claim - - def test_get_entra_conninfo_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username.""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = get_entra_conninfo(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - # Async tests - mirror of sync tests - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_with_credential(self): - """Test getting connection info with async credential and upn claim.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_no_username_claims(self): - """Test error when no username claims are present (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - await get_entra_conninfo_async(mock_credential) - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_username_priority(self): - """Test that upn takes priority over other username claims (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = { - "upn": "upn@example.com", - "preferred_username": "preferred@example.com", - "unique_name": "unique@example.com" - } - token = create_test_token(payload) +class TestSyncConnection: + def test_connect_with_existing_credentials(self): + """Test that existing user/password credentials are used without fetching Entra credentials.""" + kwargs = {"host": "localhost", "user": "existing_user", "password": "existing_password"} - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result["user"] == "upn@example.com" - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username (async).""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = await get_entra_conninfo_async(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - -class TestSyncEntraConnection: - """ - Tests for the SyncEntraConnection class. - - This class extends psycopg's Connection to automatically handle - Azure Entra ID authentication. Tests validate: - - Proper credential handling and validation - - Fallback to standard authentication when credentials exist - - Integration with the underlying psycopg connection logic - """ - - def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } - - with patch('psycopg.Connection.connect') as mock_super_connect: + with patch('psycopg.Connection.connect') as mock_connect: mock_connection = Mock() - mock_super_connect.return_value = mock_connection + mock_connect.return_value = mock_connection result = SyncEntraConnection.connect(**kwargs) - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] + assert result == mock_connection + call_args = mock_connect.call_args[1] assert call_args["user"] == "existing_user" assert call_args["password"] == "existing_password" - assert "credential" not in call_args - assert result == mock_connection - def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" + def test_connect_with_entra_credential(self): + """Test that Entra credentials are fetched and used when no user/password provided.""" mock_credential = Mock(spec=TokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb", - "credential": mock_credential - } - - expected_conninfo = {"user": "test@example.com", "password": "token123"} + kwargs = {"host": "localhost", "credential": mock_credential} - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo', return_value=expected_conninfo) as mock_get_conninfo: - with patch('psycopg.Connection.connect') as mock_super_connect: + with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo', + return_value={"user": "test@example.com", "password": "token123"}): + with patch('psycopg.Connection.connect') as mock_connect: mock_connection = Mock() - mock_super_connect.return_value = mock_connection + mock_connect.return_value = mock_connection result = SyncEntraConnection.connect(**kwargs) - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] + assert result == mock_connection + call_args = mock_connect.call_args[1] assert call_args["user"] == "test@example.com" assert call_args["password"] == "token123" - assert "credential" not in call_args - assert result == mock_connection - - def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost", "credential": invalid_credential} - - with pytest.raises(ValueError, match="credential must be a TokenCredential for synchronous connections"): - SyncEntraConnection.connect(**kwargs) + def test_connect_invalid_credential_type_throws(self): + """Test that invalid credential type raises ValueError.""" + with pytest.raises(ValueError, match="credential must be a TokenCredential for sync connections"): + SyncEntraConnection.connect(host="localhost", credential="invalid") -class TestAsyncEntraConnection: - """ - Tests for the AsyncEntraConnection class. - - This class extends psycopg's AsyncConnection to automatically handle - Azure Entra ID authentication in async contexts. Tests mirror the - synchronous tests to ensure consistent behavior between sync/async - usage patterns. - """ +class TestAsyncConnection: @pytest.mark.asyncio - async def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } + async def test_connect_with_existing_credentials(self): + """Test that existing user/password credentials are used without fetching Entra credentials (async).""" + kwargs = {"host": "localhost", "user": "existing_user", "password": "existing_password"} - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_super_connect: + with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_connect: mock_connection = Mock() - mock_super_connect.return_value = mock_connection + mock_connect.return_value = mock_connection result = await AsyncEntraConnection.connect(**kwargs) - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] + assert result == mock_connection + call_args = mock_connect.call_args[1] assert call_args["user"] == "existing_user" assert call_args["password"] == "existing_password" - assert "credential" not in call_args - assert result == mock_connection @pytest.mark.asyncio - async def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" + async def test_connect_with_entra_credential(self): + """Test that Entra credentials are fetched and used when no user/password provided (async).""" mock_credential = AsyncMock(spec=AsyncTokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb", - "credential": mock_credential - } + kwargs = {"host": "localhost", "credential": mock_credential} - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo_async', return_value=expected_conninfo) as mock_get_conninfo: - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_super_connect: + with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo_async', + return_value={"user": "test@example.com", "password": "token123"}): + with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_connect: mock_connection = Mock() - mock_super_connect.return_value = mock_connection + mock_connect.return_value = mock_connection result = await AsyncEntraConnection.connect(**kwargs) - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] + assert result == mock_connection + call_args = mock_connect.call_args[1] assert call_args["user"] == "test@example.com" assert call_args["password"] == "token123" - assert "credential" not in call_args - assert result == mock_connection @pytest.mark.asyncio - async def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost", "credential": invalid_credential} - + async def test_connect_invalid_credential_type_throws(self): + """Test that invalid credential type raises ValueError (async).""" with pytest.raises(ValueError, match="credential must be an AsyncTokenCredential for async connections"): - await AsyncEntraConnection.connect(**kwargs) + await AsyncEntraConnection.connect(host="localhost", credential="invalid") -# Example usage and test runner if __name__ == "__main__": - """ - Direct execution example for development and validation. - - This runs all tests with verbose output, which is helpful for: - - Understanding test coverage - - Debugging test failures - - Learning expected behavior patterns - - For CI/CD or automated testing, use pytest directly: - pytest test_psycopg_entra_id_extension.py -v --tb=short - """ import sys - - # Run with verbose output and short traceback format exit_code = pytest.main([__file__, "-v", "--tb=short"]) - - # Provide helpful guidance based on results - if exit_code == 0: - print("\n✅ All tests passed! The Azure Entra ID extension is working correctly.") - print("💡 Next steps: Try running the samples in samples/psycopg/getting_started/") - else: - print(f"\n❌ Some tests failed (exit code: {exit_code})") - print("💡 Check the test output above for details on any failures.") - print("💡 Ensure all dependencies are installed: pip install pytest pytest-asyncio") - - sys.exit(exit_code) - - -""" -Quick Start Guide for Understanding These Tests: - -1. Basic JWT Testing (TestDecodeJwt): - - Shows how Azure tokens are decoded to get user info - - Demonstrates different token claim formats - -2. Resource Path Parsing (TestParsePrincipalName): - - Tests managed identity resource path handling - - Shows how identity names are extracted for database auth - -3. Core Authentication Logic (TestGetEntraConninfo): - - Tests the main authentication flow - - Shows both sync and async patterns - - Demonstrates fallback mechanisms for edge cases - -4. Connection Classes (TestSyncEntraConnection/TestAsyncEntraConnection): - - Tests the customer-facing connection classes - - Shows how existing psycopg code can be easily adapted - - Validates credential handling and error cases - -Key Patterns to Notice: -- Every sync test has an async equivalent for consistency -- Mocking isolates tests from external dependencies -- Clear error messages help with debugging -- Tests serve as usage examples for developers - -For integration testing with real Azure resources, see the samples directory. -""" \ No newline at end of file + sys.exit(exit_code) \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py index 16964a8..4452f88 100644 --- a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py @@ -1,688 +1,151 @@ # Copyright (c) Microsoft. All rights reserved. -""" -Unit Tests for Azure PostgreSQL SQLAlchemy Entra ID Extension - -This test suite demonstrates and validates the Azure Entra ID authentication -functionality for PostgreSQL connections using SQLAlchemy. These tests serve as -both validation and examples of how to use the extension. - -Test Categories: -1. JWT Token Decoding - Validates Azure token processing -2. Principal Name Parsing - Tests managed identity resource path parsing -3. Connection Info Generation - Tests core authentication logic -4. Engine Creation Functions - Validates sync/async engine creation with Entra auth -5. Connection Factory Behavior - Tests custom connection factories for token refresh - -Key Testing Patterns: -- Every synchronous test has an equivalent asynchronous test -- Comprehensive mocking to avoid external dependencies -- Edge case validation for robust error handling -- Clear test naming that describes expected behavior - -Usage: - # Run all tests - pytest test_sqlalchemy_entra_id_extension.py - - # Run specific test class - pytest test_sqlalchemy_entra_id_extension.py::TestCreateEngineWithEntra - - # Run with verbose output - pytest -v test_sqlalchemy_entra_id_extension.py - -Dependencies: - pip install pytest pytest-asyncio sqlalchemy - -For more information about Azure Entra ID authentication: -https://docs.microsoft.com/en-us/azure/postgresql/concepts-aad-authentication -""" - -import base64 -import json import pytest -from unittest.mock import AsyncMock, Mock, patch, MagicMock -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential - -from azurepg_entra.sqlalchemy import ( - create_engine_with_entra, - create_async_engine_with_entra, - get_entra_conninfo, - get_entra_conninfo_async, - decode_jwt, - parse_principal_name, -) - -# Test Configuration -# These tests use mocking to avoid requiring actual Azure credentials or database connections. -# For integration testing with real Azure resources, see the samples/ directory. - - -def create_test_token(payload): - """Helper to create a test JWT token.""" - encoded_payload = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') - return f"header.{encoded_payload}.signature" - - -class TestDecodeJwt: - """ - Tests for JWT token decoding functionality. - - These tests validate that Azure Entra ID tokens are properly decoded - to extract user information. The extension supports various token formats - and claim structures that Azure may provide. - """ - - def test_decode_jwt_valid_token(self): - """Test decoding a valid JWT token with UPN (User Principal Name) claim.""" - # UPN is the most common claim for user identity in Azure AD tokens - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_with_padding(self): - """Test decoding JWT token that requires base64 padding.""" - payload = {"preferred_username": "testuser", "exp": 9999999999} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_minimal_payload(self): - """Test decoding JWT with minimal payload.""" - payload = {"unique_name": "user123"} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - -class TestParsePrincipalName: - """ - Tests for Azure resource path parsing. - - When using managed identities, Azure provides resource paths that need - to be parsed to extract the identity name for database authentication. - These tests ensure robust parsing of various path formats. - """ - - def test_parse_principal_name_valid_user_assigned(self): - """Test parsing a valid user-assigned identity resource path.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - def test_parse_principal_name_empty_string(self): - """Test parsing an empty string returns None.""" - assert parse_principal_name("") is None - - def test_parse_principal_name_none(self): - """Test parsing None returns None.""" - assert parse_principal_name(None) is None - - def test_parse_principal_name_no_slash(self): - """Test parsing a string without slashes returns None.""" - assert parse_principal_name("no-slashes-here") is None - - def test_parse_principal_name_invalid_path(self): - """Test parsing an invalid resource path returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/SomeOther/resource/my-identity") - assert result is None - - def test_parse_principal_name_missing_identity_name(self): - """Test parsing a path without identity name returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/") - assert result is None - - def test_parse_principal_name_case_insensitive(self): - """Test parsing with different case variations.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/MICROSOFT.MANAGEDIDENTITY/USERASSIGNEDIDENTITIES/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - -class TestGetEntraConninfo: - """ - Tests for the core authentication logic (sync and async versions). - - These functions handle the complete flow of: - 1. Requesting Azure tokens with appropriate scopes - 2. Decoding tokens to extract user information - 3. Handling fallback scenarios for managed identities - 4. Returning connection parameters for SQLAlchemy - - Both synchronous and asynchronous patterns are tested to ensure - consistent behavior across different usage scenarios. - """ - - # Sync tests - def test_get_entra_conninfo_with_credential(self): - """Test getting connection info with sync credential and upn claim.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - def test_get_entra_conninfo_no_username_claims(self): - """Test error when no username claims are present.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - get_entra_conninfo(mock_credential) - - def test_get_entra_conninfo_username_priority(self): - """Test that upn takes priority over other username claims.""" - mock_credential = Mock(spec=TokenCredential) - # Azure tokens may contain multiple username claims - test priority order - payload = { - "upn": "upn@example.com", # Highest priority - "preferred_username": "preferred@example.com", # Second priority - "unique_name": "unique@example.com" # Fallback option - } - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result["user"] == "upn@example.com" # Should use highest priority claim - - def test_get_entra_conninfo_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username.""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = get_entra_conninfo(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - # Async tests - mirror of sync tests - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_with_credential(self): - """Test getting connection info with async credential and upn claim.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_no_username_claims(self): - """Test error when no username claims are present (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - await get_entra_conninfo_async(mock_credential) - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_username_priority(self): - """Test that upn takes priority over other username claims (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = { - "upn": "upn@example.com", - "preferred_username": "preferred@example.com", - "unique_name": "unique@example.com" - } - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result["user"] == "upn@example.com" - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username (async).""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = await get_entra_conninfo_async(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - -class TestCreateEngineWithEntra: - """ - Tests for the create_engine_with_entra function. - - This function creates a synchronous SQLAlchemy engine with Entra authentication. - Tests validate: - - Proper engine creation with custom connection factory - - URL parsing and reconstruction - - Credential handling and validation - - Integration with SQLAlchemy's engine creation process - """ - - def test_create_engine_basic_url(self): - """Test engine creation with basic PostgreSQL URL.""" - mock_credential = Mock(spec=TokenCredential) - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra(url, credential=mock_credential) - - # Verify create_engine was called with base URL and creator function - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - - # Check that base URL doesn't contain credentials - assert "myserver.postgres.database.azure.com" in call_args[0][0] - assert "creator" in call_args[1] - assert callable(call_args[1]["creator"]) - assert result == mock_engine - - def test_create_engine_with_psycopg_scheme(self): - """Test engine creation with psycopg+ scheme.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com:5432/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra(url) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - - # Verify scheme is preserved in base URL - assert call_args[0][0].startswith("postgresql+psycopg://") - assert result == mock_engine - - def test_create_engine_with_query_parameters(self): - """Test engine creation preserves query parameters.""" - url = "postgresql://myserver.postgres.database.azure.com/mydatabase?sslmode=require&connect_timeout=30" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra(url) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - - # Verify query parameters are preserved - assert "sslmode=require" in call_args[0][0] - assert "connect_timeout=30" in call_args[0][0] - assert result == mock_engine - - def test_create_engine_invalid_credential_type(self): - """Test engine creation with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - - with pytest.raises(ValueError, match="credential must be a TokenCredential for synchronous engines"): - create_engine_with_entra(url, credential=invalid_credential) - - def test_create_engine_connection_factory_behavior(self): - """Test that the custom connection factory works correctly.""" - mock_credential = Mock(spec=TokenCredential) - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - # Mock the connection factory components - mock_temp_engine = Mock() - mock_raw_conn = Mock() - mock_dbapi_conn = Mock() - mock_raw_conn.dbapi_connection = mock_dbapi_conn - mock_temp_engine.raw_connection.return_value = mock_raw_conn - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo', return_value=expected_conninfo): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - # Mock the main create_engine call - mock_main_engine = Mock() - # Mock the temp engine creation inside connection factory - mock_create_engine.side_effect = [mock_main_engine, mock_temp_engine] +from unittest.mock import Mock, patch + +class TestEnableEntraAuthentication: + def test_sync_authentication_function_registration(self): + """Test that enable_entra_authentication registers event listener successfully.""" + mock_engine = Mock() + + with patch('sqlalchemy.event.listens_for') as mock_event_listener: + from azurepg_entra.sqlalchemy import enable_entra_authentication + enable_entra_authentication(mock_engine) + + # Verify event listener was registered with correct parameters + mock_event_listener.assert_called_once_with(mock_engine, "do_connect") + + def test_provide_token_method(self): + """Test the provide_token event handler method directly.""" + mock_engine = Mock() + + # Capture the event handler function + captured_handler = None + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + return decorator + + with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): + with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo') as mock_get_creds: + mock_get_creds.return_value = {"user": "test@example.com", "password": "test_token"} - result = create_engine_with_entra(url, credential=mock_credential) - assert result == mock_main_engine + from azurepg_entra.sqlalchemy import enable_entra_authentication + enable_entra_authentication(mock_engine) - # Verify main engine creation - assert mock_create_engine.call_count >= 1 - main_call_args = mock_create_engine.call_args_list[0] + # Test the captured handler directly + mock_cparams = {} + captured_handler(None, None, None, mock_cparams) - # Extract and test the connection factory function - creator_func = main_call_args[1]["creator"] - assert callable(creator_func) + # Verify credentials were added + mock_get_creds.assert_called_once_with(None) + assert mock_cparams["user"] == "test@example.com" + assert mock_cparams["password"] == "test_token" + + def test_provide_token_skips_existing_credentials(self): + """Test that provide_token skips when credentials already exist.""" + mock_engine = Mock() + + # Capture the event handler function + captured_handler = None + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + return decorator + + with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): + with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo') as mock_get_creds: + from azurepg_entra.sqlalchemy import enable_entra_authentication + enable_entra_authentication(mock_engine) - # Test the connection factory function (still within the outer patch context) - conn_result = creator_func() - assert conn_result == mock_dbapi_conn - - def test_create_engine_with_kwargs(self): - """Test that additional kwargs are passed through to create_engine.""" - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra( - url, - pool_size=10, - max_overflow=20, - echo=True - ) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args[1] - - # Verify additional kwargs are passed through - assert call_args["pool_size"] == 10 - assert call_args["max_overflow"] == 20 - assert call_args["echo"] == True - assert result == mock_engine - - -class TestCreateAsyncEngineWithEntra: - """ - Tests for the create_async_engine_with_entra function. - - This function creates an asynchronous SQLAlchemy engine with Entra authentication. - Tests validate: - - Proper async engine creation with custom connection factory - - URL parsing and psycopg3 async connection handling - - Credential handling for async operations - - Integration with SQLAlchemy's async engine creation process - """ - - def test_create_async_engine_import_error(self): - """Test error when sqlalchemy.ext.asyncio is not available.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - # Mock the absence of create_async_engine - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine', None): - with pytest.raises(ImportError, match="sqlalchemy.ext.asyncio is required for async engines"): - create_async_engine_with_entra(url) - - def test_create_async_engine_basic_url(self): - """Test async engine creation with basic PostgreSQL URL.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - result = create_async_engine_with_entra(url, credential=mock_credential) - - # Verify create_async_engine was called with base URL and async_creator function - mock_create_async_engine.assert_called_once() - call_args = mock_create_async_engine.call_args - - # Check that base URL doesn't contain credentials - assert "myserver.postgres.database.azure.com" in call_args[0][0] - assert "async_creator" in call_args[1] - assert callable(call_args[1]["async_creator"]) - assert result == mock_engine - - def test_create_async_engine_with_port_and_query(self): - """Test async engine creation preserves port and query parameters.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com:5432/mydatabase?sslmode=require" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - result = create_async_engine_with_entra(url) - - mock_create_async_engine.assert_called_once() - call_args = mock_create_async_engine.call_args - - # Verify port and query parameters are preserved - assert ":5432" in call_args[0][0] - assert "sslmode=require" in call_args[0][0] - assert result == mock_engine - - def test_create_async_engine_invalid_credential_type(self): - """Test async engine creation with invalid credential type raises error.""" - invalid_credential = "not_an_async_credential" - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - with pytest.raises(ValueError, match="credential must be an AsyncTokenCredential for async engines"): - create_async_engine_with_entra(url, credential=invalid_credential) - - @pytest.mark.asyncio - async def test_create_async_engine_connection_factory_behavior(self): - """Test that the custom async connection factory works correctly.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - # Mock psycopg AsyncConnection - mock_async_conn = Mock() - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo_async', return_value=expected_conninfo): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock, return_value=mock_async_conn) as mock_psycopg_connect: - mock_main_engine = Mock() - mock_create_async_engine.return_value = mock_main_engine - - result = create_async_engine_with_entra(url, credential=mock_credential) - - # Verify main engine creation - mock_create_async_engine.assert_called_once() - main_call_args = mock_create_async_engine.call_args - - # Extract and test the async connection factory function - async_creator_func = main_call_args[1]["async_creator"] - assert callable(async_creator_func) + # Test with existing credentials + mock_cparams = {"user": "existing@example.com", "password": "existing_password"} + captured_handler(None, None, None, mock_cparams) + + # Verify get_entra_conninfo was not called + mock_get_creds.assert_not_called() + assert mock_cparams["user"] == "existing@example.com" + assert mock_cparams["password"] == "existing_password" + + def test_async_authentication_function_registration(self): + """Test that enable_entra_authentication_async registers event listener successfully.""" + mock_async_engine = Mock() + mock_sync_engine = Mock() + mock_async_engine.sync_engine = mock_sync_engine + + with patch('sqlalchemy.event.listens_for') as mock_event_listener: + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + enable_entra_authentication_async(mock_async_engine) + + # Verify event listener was registered on sync_engine + mock_event_listener.assert_called_once_with(mock_sync_engine, "do_connect") + + def test_provide_token_async_method(self): + """Test the provide_token_async event handler method directly.""" + mock_async_engine = Mock() + mock_sync_engine = Mock() + mock_async_engine.sync_engine = mock_sync_engine + + # Capture the event handler function + captured_handler = None + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + return decorator + + with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): + with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo_async') as mock_get_creds_async: + with patch('asyncio.run') as mock_asyncio_run: + mock_get_creds_async.return_value = {"user": "test@example.com", "password": "test_token"} + mock_asyncio_run.return_value = {"user": "test@example.com", "password": "test_token"} - # Test the async connection factory function - conn_result = await async_creator_func() + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + enable_entra_authentication_async(mock_async_engine) - # Verify psycopg.AsyncConnection.connect was called with correct parameters - mock_psycopg_connect.assert_called_once() - connect_call_args = mock_psycopg_connect.call_args[1] - assert connect_call_args["host"] == "myserver.postgres.database.azure.com" - assert connect_call_args["port"] == 5432 - assert connect_call_args["dbname"] == "mydatabase" - assert connect_call_args["user"] == "test@example.com" - assert connect_call_args["password"] == "token123" + # Test the captured handler directly + mock_cparams = {} + captured_handler(None, None, None, mock_cparams) - assert conn_result == mock_async_conn - assert result == mock_main_engine - - def test_create_async_engine_unsupported_scheme(self): - """Test async engine creation with unsupported URL scheme.""" - url = "postgresql+asyncpg://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - # Create the engine (this should succeed) - result = create_async_engine_with_entra(url) - - # Extract the async_creator function - call_args = mock_create_async_engine.call_args - async_creator_func = call_args[1]["async_creator"] - - # Test that calling the async_creator with unsupported scheme raises error - with pytest.raises(ValueError, match="Unsupported async URL scheme: postgresql\\+asyncpg"): - # We need to create an async context to test this - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(async_creator_func()) - finally: - loop.close() - - def test_create_async_engine_with_kwargs(self): - """Test that additional kwargs are passed through to create_async_engine.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - result = create_async_engine_with_entra( - url, - pool_size=15, - max_overflow=25, - echo=True - ) - - mock_create_async_engine.assert_called_once() - call_args = mock_create_async_engine.call_args[1] - - # Verify additional kwargs are passed through - assert call_args["pool_size"] == 15 - assert call_args["max_overflow"] == 25 - assert call_args["echo"] == True - assert result == mock_engine - - -class TestUrlParsing: - """ - Tests for URL parsing and reconstruction logic. - - These tests validate that the URL parsing and reconstruction logic - works correctly for various URL formats and edge cases. - """ - - def test_url_parsing_with_different_schemes(self): - """Test URL parsing works with different PostgreSQL schemes.""" - test_urls = [ - "postgresql://server.com/db", - "postgresql+psycopg://server.com/db", - "postgresql+psycopg2://server.com/db" - ] - - for url in test_urls: - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine + # Verify credentials were added (either through direct call or asyncio.run) + assert mock_cparams["user"] == "test@example.com" + assert mock_cparams["password"] == "test_token" + + def test_provide_token_async_skips_existing_credentials(self): + """Test that provide_token_async skips when credentials already exist.""" + mock_async_engine = Mock() + mock_sync_engine = Mock() + mock_async_engine.sync_engine = mock_sync_engine + + # Capture the event handler function + captured_handler = None + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + return decorator + + with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): + with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo') as mock_get_creds: + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + enable_entra_authentication_async(mock_async_engine) - create_engine_with_entra(url) + # Test with existing credentials + mock_cparams = {"user": "existing@example.com", "password": "existing_password"} + captured_handler(None, None, None, mock_cparams) - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - # Verify the scheme is preserved in the base URL - assert call_args[0][0].startswith(url.split('://')[0] + '://') - - def test_url_parsing_with_complex_parameters(self): - """Test URL parsing with complex query parameters and paths.""" - url = "postgresql://server.com:5432/complex_db_name?sslmode=require&application_name=test%20app&connect_timeout=30" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - create_engine_with_entra(url) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - base_url = call_args[0][0] - - # Verify all components are preserved - assert "server.com:5432" in base_url - assert "/complex_db_name" in base_url - assert "sslmode=require" in base_url - assert "application_name=test%20app" in base_url - assert "connect_timeout=30" in base_url + # Verify get_entra_conninfo was not called + mock_get_creds.assert_not_called() + assert mock_cparams["user"] == "existing@example.com" + assert mock_cparams["password"] == "existing_password" -# Example usage and test runner if __name__ == "__main__": - """ - Direct execution example for development and validation. - - This runs all tests with verbose output, which is helpful for: - - Understanding test coverage - - Debugging test failures - - Learning expected behavior patterns - - For CI/CD or automated testing, use pytest directly: - pytest test_sqlalchemy_entra_id_extension.py -v --tb=short - """ import sys - - # Run with verbose output and short traceback format exit_code = pytest.main([__file__, "-v", "--tb=short"]) - - # Provide helpful guidance based on results - if exit_code == 0: - print("\n✅ All tests passed! The SQLAlchemy Azure Entra ID extension is working correctly.") - print("💡 Next steps: Try running the samples in samples/sqlalchemy/getting_started/") - else: - print(f"\n❌ Some tests failed (exit code: {exit_code})") - print("💡 Check the test output above for details on any failures.") - print("💡 Ensure all dependencies are installed: pip install pytest pytest-asyncio sqlalchemy") - - sys.exit(exit_code) - - -""" -Quick Start Guide for Understanding These Tests: - -1. Basic JWT Testing (TestDecodeJwt): - - Shows how Azure tokens are decoded to get user info - - Demonstrates different token claim formats - -2. Resource Path Parsing (TestParsePrincipalName): - - Tests managed identity resource path handling - - Shows how identity names are extracted for database auth - -3. Core Authentication Logic (TestGetEntraConninfo): - - Tests the main authentication flow - - Shows both sync and async patterns - - Demonstrates fallback mechanisms for edge cases - -4. Engine Creation Functions (TestCreateEngineWithEntra/TestCreateAsyncEngineWithEntra): - - Tests the main SQLAlchemy engine creation functions - - Shows how custom connection factories work - - Validates URL parsing and reconstruction - - Tests both sync and async engine patterns - -5. URL Parsing (TestUrlParsing): - - Tests URL parsing and reconstruction logic - - Shows how different schemes and parameters are handled - - Validates complex URL scenarios - -Key Patterns to Notice: -- Every sync test has an async equivalent for consistency -- Mocking isolates tests from external dependencies -- Custom connection factories are thoroughly tested -- URL parsing handles various PostgreSQL driver schemes -- Clear error messages help with debugging -- Tests serve as usage examples for developers - -For integration testing with real Azure resources, see the samples directory. - -SQLAlchemy-Specific Features Tested: -- Custom connection factory (creator parameter) -- Custom async connection factory (async_creator parameter) -- Engine configuration parameter pass-through -- URL scheme handling for different PostgreSQL drivers -- Integration with SQLAlchemy's connection pooling -- Proper DBAPI connection object handling -""" \ No newline at end of file + sys.exit(exit_code) \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/test_core_functionality.py b/python/tests/azure/data/postgresql/test_core_functionality.py new file mode 100644 index 0000000..38452bc --- /dev/null +++ b/python/tests/azure/data/postgresql/test_core_functionality.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft. All rights reserved. +import jwt +import pytest +from unittest.mock import AsyncMock, Mock, patch +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azurepg_entra.core import ( + decode_jwt, + parse_principal_name, + get_entra_conninfo, + get_entra_conninfo_async, +) + +def create_test_token(payload): + """Helper to create a test JWT token.""" + return jwt.encode(payload, key="", algorithm="none") + +class TestJwtParsing: + def test_decode_jwt_with_upn(self): + payload = {"upn": "user@example.com"} + token = create_test_token(payload) + result = decode_jwt(token) + assert result == payload + + def test_decode_jwt_with_preferred_username(self): + payload = {"preferred_username": "testuser@example.com"} + token = create_test_token(payload) + result = decode_jwt(token) + assert result == payload + + def test_decode_jwt_invalid_format_returns_none(self): + result = decode_jwt("invalid.token") + assert result is None + + def test_parse_principal_name_valid_path(self): + path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" + result = parse_principal_name(path) + assert result == "my-identity" + + def test_parse_principal_name_invalid_path_returns_none(self): + assert parse_principal_name("") is None + assert parse_principal_name(None) is None + assert parse_principal_name("/invalid/path") is None + + +class TestEntraAuthentication: + def test_get_entra_conninfo_with_upn(self): + mock_credential = Mock(spec=TokenCredential) + payload = {"upn": "user@example.com"} + token = create_test_token(payload) + + with patch('azurepg_entra.core.get_entra_token', return_value=token): + result = get_entra_conninfo(mock_credential) + assert result == {"user": "user@example.com", "password": token} + + def test_get_entra_conninfo_no_username_throws(self): + mock_credential = Mock(spec=TokenCredential) + payload = {"sub": "subject123"} + token = create_test_token(payload) + + with patch('azurepg_entra.core.get_entra_token', return_value=token): + with pytest.raises(ValueError, match="Could not determine username from token claims"): + get_entra_conninfo(mock_credential) + + @pytest.mark.asyncio + async def test_get_entra_conninfo_async_with_upn(self): + mock_credential = AsyncMock(spec=AsyncTokenCredential) + payload = {"upn": "user@example.com"} + token = create_test_token(payload) + + with patch('azurepg_entra.core.get_entra_token_async', return_value=token): + result = await get_entra_conninfo_async(mock_credential) + assert result == {"user": "user@example.com", "password": token} + + @pytest.mark.asyncio + async def test_get_entra_conninfo_async_no_username_throws(self): + mock_credential = AsyncMock(spec=AsyncTokenCredential) + payload = {"sub": "subject123"} + token = create_test_token(payload) + + with patch('azurepg_entra.core.get_entra_token_async', return_value=token): + with pytest.raises(ValueError, match="Could not determine username from token claims"): + await get_entra_conninfo_async(mock_credential) + + +if __name__ == "__main__": + import sys + exit_code = pytest.main([__file__, "-v", "--tb=short"]) + sys.exit(exit_code) \ No newline at end of file From 99ef432f4a1145cf17283762ae7935efa05f9777 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Sun, 5 Oct 2025 15:45:48 -0700 Subject: [PATCH 04/19] Update sqlalchemy tests --- .../test_sqlalchemy_entra_id_extension.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py index 4452f88..a4f1bf5 100644 --- a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py @@ -100,20 +100,18 @@ def decorator(func): with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo_async') as mock_get_creds_async: - with patch('asyncio.run') as mock_asyncio_run: - mock_get_creds_async.return_value = {"user": "test@example.com", "password": "test_token"} - mock_asyncio_run.return_value = {"user": "test@example.com", "password": "test_token"} - - from azurepg_entra.sqlalchemy import enable_entra_authentication_async - enable_entra_authentication_async(mock_async_engine) - - # Test the captured handler directly - mock_cparams = {} - captured_handler(None, None, None, mock_cparams) - - # Verify credentials were added (either through direct call or asyncio.run) - assert mock_cparams["user"] == "test@example.com" - assert mock_cparams["password"] == "test_token" + mock_get_creds_async.return_value = {"user": "test@example.com", "password": "test_token"} + + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + enable_entra_authentication_async(mock_async_engine) + + # Test the captured handler directly + mock_cparams = {} + captured_handler(None, None, None, mock_cparams) + + # Verify credentials were added (asyncio.run is always called for async credential fetching) + assert mock_cparams["user"] == "test@example.com" + assert mock_cparams["password"] == "test_token" def test_provide_token_async_skips_existing_credentials(self): """Test that provide_token_async skips when credentials already exist.""" @@ -131,7 +129,7 @@ def decorator(func): return decorator with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo') as mock_get_creds: + with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo_async') as mock_get_creds_async: from azurepg_entra.sqlalchemy import enable_entra_authentication_async enable_entra_authentication_async(mock_async_engine) @@ -139,10 +137,11 @@ def decorator(func): mock_cparams = {"user": "existing@example.com", "password": "existing_password"} captured_handler(None, None, None, mock_cparams) - # Verify get_entra_conninfo was not called - mock_get_creds.assert_not_called() + # Verify get_entra_conninfo_async was not called (credentials already exist) + mock_get_creds_async.assert_not_called() assert mock_cparams["user"] == "existing@example.com" assert mock_cparams["password"] == "existing_password" + if __name__ == "__main__": From ab1db2c11dbac857031454fc9a35d026a4bc586f Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Tue, 7 Oct 2025 22:08:30 -0700 Subject: [PATCH 05/19] Major refactoring including adding exception handling and renaming --- python/README.md | 152 ++++++++---------- python/pyproject.toml | 8 +- .../create_db_connection_psycopg2.py | 12 +- .../create_db_connection_psycopg3.py | 20 ++- .../create_db_connection_sqlalchemy.py | 20 ++- python/src/azurepg_entra/core.py | 102 +++++++----- python/src/azurepg_entra/errors.py | 23 +++ python/src/azurepg_entra/psycopg2/__init__.py | 12 +- .../psycopg2/entra_connection.py | 59 +++++++ .../psycopg2/psycopg2_entra_id_extension.py | 44 ----- python/src/azurepg_entra/psycopg3/__init__.py | 14 +- .../psycopg3/async_entra_connection.py | 55 +++++++ .../psycopg3/entra_connection.py | 54 +++++++ .../psycopg3/psycopg3_entra_id_extension.py | 81 ---------- .../src/azurepg_entra/sqlalchemy/__init__.py | 8 +- .../sqlalchemy/async_entra_connection.py | 44 +++++ .../sqlalchemy/entra_connection.py | 43 +++++ .../sqlalchemy_entra_id_extension.py | 102 ------------ .../test_psycopg2_entra_id_extension.py | 6 +- .../test_psycopg3_entra_id_extension.py | 23 +-- .../test_sqlalchemy_entra_id_extension.py | 17 +- .../postgresql/test_core_functionality.py | 25 ++- 22 files changed, 497 insertions(+), 427 deletions(-) create mode 100644 python/src/azurepg_entra/errors.py create mode 100644 python/src/azurepg_entra/psycopg2/entra_connection.py delete mode 100644 python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py create mode 100644 python/src/azurepg_entra/psycopg3/async_entra_connection.py create mode 100644 python/src/azurepg_entra/psycopg3/entra_connection.py delete mode 100644 python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py create mode 100644 python/src/azurepg_entra/sqlalchemy/async_entra_connection.py create mode 100644 python/src/azurepg_entra/sqlalchemy/entra_connection.py delete mode 100644 python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py diff --git a/python/README.md b/python/README.md index 5e1b369..6efe592 100644 --- a/python/README.md +++ b/python/README.md @@ -101,7 +101,7 @@ Choose the driver that best fits your project needs: - **psycopg3**: Modern PostgreSQL driver (recommended for new projects) - **psycopg2**: Legacy PostgreSQL driver (for existing projects) -- **SQLAlchemy**: High-level ORM/Core interface using psycopg3 backend +- **SQLAlchemy**: High-level ORM/Core interface --- @@ -109,54 +109,40 @@ Choose the driver that best fits your project needs: > **Note**: psycopg2 is in maintenance mode. For new projects, consider using psycopg3 instead. -The psycopg2 integration provides both synchronous (psycopg2) and asynchronous (aiopg) connection support with Azure Entra ID authentication. +The psycopg2 integration provides synchronous connection support with Azure Entra ID authentication through connection pooling. ### Installation ```bash pip install "azurepg-entra[psycopg2]" ``` -### Synchronous Connection (psycopg2) +### Connection Pooling (Recommended) ```python -from azurepg_entra.psycopg2 import connect_with_entra +from azurepg_entra.psycopg2 import EntraConnection from psycopg2 import pool +import os def main(): - # Direct connection - conn = connect_with_entra( + # Connection pooling with Entra authentication + connection_pool = pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, host="your-server.postgres.database.azure.com", - port=5432, - dbname="your_database" + database="your_database", + connection_factory=EntraConnection ) + # Get a connection from the pool + conn = connection_pool.getconn() + try: with conn.cursor() as cur: cur.execute("SELECT current_user, now()") user, time = cur.fetchone() print(f"Connected as: {user} at {time}") finally: - conn.close() - - # Connection pooling - def entra_connection_factory(*args, **kwargs): - return connect_with_entra( - host="your-server.postgres.database.azure.com", - port=5432, - dbname="your_database" - ) - - connection_pool = pool.ThreadedConnectionPool( - minconn=1, maxconn=5, - connection_factory=entra_connection_factory - ) - - conn = connection_pool.getconn() - try: - with conn.cursor() as cur: - cur.execute("SELECT current_user") - print(f"Pool connection as: {cur.fetchone()[0]}") - finally: + # Return connection to pool connection_pool.putconn(conn) connection_pool.closeall() @@ -164,30 +150,27 @@ if __name__ == "__main__": main() ``` -### Asynchronous Connection (aiopg) +### Direct Connection ```python -import asyncio -from azurepg_entra.psycopg2 import connect_with_entra_async +from azurepg_entra.psycopg2 import EntraConnection -async def main(): - # Direct async connection - conn = await connect_with_entra_async( - host="your-server.postgres.database.azure.com", - port=5432, - dbname="your_database" +def main(): + # Direct connection using DSN + conn = EntraConnection( + "postgresql://your-server.postgres.database.azure.com:5432/your_database" ) try: - async with conn.cursor() as cur: - await cur.execute("SELECT current_user, now()") - user, time = await cur.fetchone() - print(f"Async connected as: {user} at {time}") + with conn.cursor() as cur: + cur.execute("SELECT current_user, now()") + user, time = cur.fetchone() + print(f"Connected as: {user} at {time}") finally: conn.close() if __name__ == "__main__": - asyncio.run(main()) + main() ``` --- @@ -204,32 +187,28 @@ pip install "azurepg-entra[psycopg3]" ### Synchronous Connection ```python -from azurepg_entra.psycopg3 import SyncEntraConnection +from azurepg_entra.psycopg3 import EntraConnection from psycopg_pool import ConnectionPool def main(): - # Direct connection - with SyncEntraConnection.connect( - "postgresql://your-server.postgres.database.azure.com:5432/your_database" - ) as conn: - with conn.cursor() as cur: - cur.execute("SELECT current_user, now()") - user, time = cur.fetchone() - print(f"Connected as: {user} at {time}") - # Connection pooling (recommended for production) - with ConnectionPool( + pool = ConnectionPool( conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", - connection_class=SyncEntraConnection, + connection_class=EntraConnection, min_size=1, # keep at least 1 connection always open max_size=5, # allow up to 5 concurrent connections - max_waiting=10, # seconds to wait if pool is full - ) as pool: + open=False + ) + + pool.open() + try: with pool.connection() as conn: with conn.cursor() as cur: cur.execute("SELECT current_user, now()") user, time = cur.fetchone() - print(f"Pool connection as: {user} at {time}") + print(f"Connected as: {user} at {time}") + finally: + pool.close() if __name__ == "__main__": main() @@ -244,32 +223,28 @@ from azurepg_entra.psycopg3 import AsyncEntraConnection from psycopg_pool import AsyncConnectionPool async def main(): - # Direct async connection - async with await AsyncEntraConnection.connect( - "postgresql://your-server.postgres.database.azure.com:5432/your_database" - ) as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT current_user, now()") - user, time = await cur.fetchone() - print(f"Async connected as: {user} at {time}") - # Async connection pooling (recommended for production) - async with AsyncConnectionPool( + pool = AsyncConnectionPool( conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", connection_class=AsyncEntraConnection, min_size=1, # keep at least 1 connection always open max_size=5, # allow up to 5 concurrent connections - max_waiting=10, # seconds to wait if pool is full - ) as pool: + open=False + ) + + await pool.open() + try: async with pool.connection() as conn: async with conn.cursor() as cur: await cur.execute("SELECT current_user, now()") user, time = await cur.fetchone() - print(f"Pool connection as: {user} at {time}") + print(f"Async connected as: {user} at {time}") + finally: + await pool.close() if __name__ == "__main__": # Windows compatibility for async operations - if sys.platform == "win32": + if sys.platform.startswith('win'): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.run(main()) @@ -279,7 +254,7 @@ if __name__ == "__main__": ## SQLAlchemy Integration -SQLAlchemy integration uses psycopg3 as the backend driver with automatic Entra ID authentication. +SQLAlchemy integration uses psycopg3 as the backend driver with automatic Entra ID authentication through event listeners. ### Installation ```bash @@ -289,14 +264,15 @@ pip install "azurepg-entra[sqlalchemy]" ### Synchronous Engine ```python -from azurepg_entra.sqlalchemy import create_entra_engine -from sqlalchemy import text +from sqlalchemy import create_engine, text +from azurepg_entra.sqlalchemy import enable_entra_authentication def main(): - # Create synchronous engine with Entra ID authentication - engine = create_entra_engine( - "postgresql+psycopg://your-server.postgres.database.azure.com:5432/your_database" - ) + # Create synchronous engine + engine = create_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") + + # Enable Entra ID authentication + enable_entra_authentication(engine) # Core usage with engine.connect() as conn: @@ -324,15 +300,16 @@ if __name__ == "__main__": ```python import asyncio import sys -from azurepg_entra.sqlalchemy import create_async_entra_engine +from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy import text -from sqlalchemy.ext.asyncio import async_sessionmaker +from azurepg_entra.sqlalchemy import enable_entra_authentication_async async def main(): - # Create asynchronous engine with Entra ID authentication - engine = await create_async_entra_engine( - "postgresql+psycopg://your-server.postgres.database.azure.com:5432/your_database" - ) + # Create asynchronous engine + engine = create_async_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") + + # Enable Entra ID authentication for async + enable_entra_authentication_async(engine) # Async Core usage async with engine.connect() as conn: @@ -341,6 +318,7 @@ async def main(): print(f"Async SQLAlchemy connected as: {user} at {time}") # Async ORM usage + from sqlalchemy.ext.asyncio import async_sessionmaker AsyncSession = async_sessionmaker(engine, expire_on_commit=False) async with AsyncSession() as session: @@ -352,7 +330,7 @@ async def main(): if __name__ == "__main__": # Windows compatibility for async operations - if sys.platform == "win32": + if sys.platform.startswith('win'): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.run(main()) @@ -380,8 +358,6 @@ The package automatically requests the correct OAuth2 scopes: - **⏰ Automatic expiration**: Tokens expire and are refreshed automatically - **🛡️ SSL enforcement**: All connections require SSL encryption - **🔑 Principle of least privilege**: Only database-specific scopes are requested -- **📋 Audit logging**: Authentication events are logged by Azure Database for PostgreSQL - --- ## Troubleshooting diff --git a/python/pyproject.toml b/python/pyproject.toml index 70bc2f1..1d2be75 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,14 +20,13 @@ classifiers = [ dependencies = [ "azure-identity>=1.13.0", "azure-core>=1.24.0", - "aiohttp>=3.8.0", - "PyJWT>=2.0.0" ] [project.optional-dependencies] # psycopg3 support psycopg3 = [ "psycopg[binary]>=3.1.0", + "aiohttp>=3.8.0" ] # psycopg2 support @@ -38,6 +37,7 @@ psycopg2 = [ # SQLAlchemy support sqlalchemy = [ "sqlalchemy>=2.0.0", + "aiohttp>=3.8.0" ] # Development dependencies @@ -51,11 +51,9 @@ dev = [ # All optional dependencies combined all = [ "psycopg[binary]>=3.1.0", - "psycopg-pool>=3.1.0", + "aiohttp>=3.8.0", "psycopg2-binary>=2.9.0", - "aiopg>=1.4.0", "sqlalchemy>=2.0.0", - "asyncpg>=0.28.0", "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "python-dotenv>=1.0.0", diff --git a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py index 1a94717..ba0c5ed 100644 --- a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py +++ b/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py @@ -5,7 +5,7 @@ from dotenv import load_dotenv import os from psycopg2 import pool -from azurepg_entra.psycopg2 import SyncEntraConnection +from azurepg_entra.psycopg2 import EntraConnection # Load environment variables from .env file load_dotenv() @@ -14,15 +14,17 @@ def main_sync(): try: - # We pass in the SyncEntraConnection class to enable Entra authentication for the - # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using - # the token itself (with the PostgreSQL scope) as the password. + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/docs/advanced.html#subclassing-connection connection_pool = pool.ThreadedConnectionPool( minconn=1, maxconn=5, host=SERVER, database=DATABASE, - connection_factory=SyncEntraConnection + connection_factory=EntraConnection ) # Get a connection from the pool diff --git a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py index 6445189..c5c0b80 100644 --- a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py +++ b/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py @@ -9,7 +9,7 @@ import asyncio import sys import os -from azurepg_entra.psycopg3 import SyncEntraConnection, AsyncEntraConnection +from azurepg_entra.psycopg3 import EntraConnection, AsyncEntraConnection # Load environment variables from .env file load_dotenv() @@ -20,15 +20,17 @@ def main_sync(): """Synchronous connection example using psycopg with Entra ID authentication.""" try: - # We pass in the SyncEntraConnection class to enable Entra authentication for the - # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using - # the token itself (with the PostgreSQL scope) as the password. + # We use the SyncEntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect pool = ConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, max_size=5, open=False, - connection_class=SyncEntraConnection + connection_class=EntraConnection ) pool.open() with pool, pool.connection() as conn, conn.cursor() as cur: @@ -49,9 +51,11 @@ async def main_async(): """Asynchronous connection example using psycopg with Entra ID authentication.""" try: - # We pass in the AsyncEntraConnection class to enable Entra authentication for the - # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using - # the token itself (with the PostgreSQL scope) as the password. + # We use the AsyncEntraConnection class to enable asynchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect pool = AsyncConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, diff --git a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py index 56ce94e..b9fd360 100644 --- a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py +++ b/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py @@ -24,10 +24,12 @@ def main_sync(): # Create a synchronous engine engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") - # We add an event listener to the engine to enable Entra authentication for the - # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using - # the token itself (with the PostgreSQL scope) as the password. This event listener is triggered - # whenever we get a NEW connection from the pool backing the engine. + # We add an event listener to the engine to enable synchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function enable_entra_authentication(engine) with engine.connect() as conn: @@ -52,10 +54,12 @@ async def main_async(): # Create an asynchronous engine engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") - # We add an event listener to the engine to enable Entra authentication for the - # PostgreSQL database by acquiring an Azure access token, extracting a username from the token, and using - # the token itself (with the PostgreSQL scope) as the password. This event listener is triggered - # whenever we get a NEW connection from the pool backing the engine. + # We add an event listener to the engine to enable asynchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function enable_entra_authentication_async(engine) async with engine.connect() as conn: diff --git a/python/src/azurepg_entra/core.py b/python/src/azurepg_entra/core.py index 626ce77..8ebfcc8 100644 --- a/python/src/azurepg_entra/core.py +++ b/python/src/azurepg_entra/core.py @@ -1,12 +1,14 @@ -import logging -import jwt +import base64 +import json from typing import Any, cast from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import ClientAuthenticationError from azure.identity import DefaultAzureCredential as DefaultAzureCredential from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential +from azure.identity import CredentialUnavailableError +from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, ScopePermissionError -logger = logging.getLogger(__name__) AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" AZURE_MANAGEMENT_SCOPE = "https://management.azure.com/.default" @@ -21,8 +23,6 @@ def get_entra_token(credential: TokenCredential | None, scope: str) -> str: Returns: str: The acquired authentication token to be used as the database password. """ - logger.info("Acquiring Entra token for postgres password") - credential = credential or DefaultAzureCredential() cred = credential.get_token(scope) return cred.token @@ -38,29 +38,30 @@ async def get_entra_token_async(credential: AsyncTokenCredential | None, scope: Returns: str: The acquired authentication token to be used as the database password. """ - logger.info("Acquiring Entra token for postgres password") - credential = credential or AsyncDefaultAzureCredential() async with credential: cred = await credential.get_token(scope) return cred.token -def decode_jwt(token: str) -> dict[str, Any] | None: +def decode_jwt(token: str) -> dict[str, Any]: """Decodes a JWT token to extract its payload claims. Parameters: token (str): The JWT token string in the standard three-part format. Returns: - dict | None: A dictionary containing the claims extracted from the token payload, - or None if the token is invalid. + dict[str, Any]: A dictionary containing the claims extracted from the token payload. + + Raises: + TokenValueError: If the token format is invalid or cannot be decoded. """ try: - # Decode without verification since we only need the payload claims - # Azure tokens are already validated by the credential provider - return cast(dict[str, Any], jwt.decode(token, options={"verify_signature": False})) - except Exception: - return None + payload = token.split(".")[1] + padding = "=" * (4 - len(payload) % 4) + decoded_payload = base64.urlsafe_b64decode(payload + padding) + return cast(dict[str, Any], json.loads(decoded_payload)) + except Exception as e: + raise TokenDecodeError("Invalid JWT token format") from e def parse_principal_name(xms_mirid: str) -> str | None: """Parses the principal name from an Azure resource path. @@ -91,23 +92,31 @@ def parse_principal_name(xms_mirid: str) -> str | None: def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: """Synchronously obtains connection information from Entra authentication for Azure PostgreSQL. + This function acquires an access token from Azure Entra ID and extracts the username + from the token claims. It tries multiple claim sources to determine the username. + Parameters: credential (TokenCredential or None): The credential used for token acquisition. - If None, the default Azure credentials are used. + If None, DefaultAzureCredential() is used to automatically discover credentials. Returns: - dict[str, str]: A dictionary with 'user' and 'password' keys containing the username and token. + dict[str, str]: A dictionary with 'user' and 'password' keys, where: + - 'user': The extracted username from token claims + - 'password': The Entra ID access token for database authentication Raises: - ValueError: If the username cannot be extracted from the token payload. + TokenDecodeError: If the JWT token cannot be decoded or is malformed. + UsernameExtractionError: If the username cannot be extracted from token claims. + ScopePermissionError: The token could not be acquired from the management scope, possibly due to insufficient permissions. """ credential = credential or DefaultAzureCredential() # Always get the DB-scope token for password db_token = get_entra_token(credential, AZURE_DB_FOR_POSTGRES_SCOPE) - db_claims = decode_jwt(db_token) - if not db_claims: - raise ValueError("Invalid DB token format") + try: + db_claims = decode_jwt(db_token) + except TokenDecodeError: + raise xms_mirid = db_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None @@ -118,10 +127,14 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: if not username: # Fall back to management scope ONLY to discover username - mgmt_token = get_entra_token(credential, AZURE_MANAGEMENT_SCOPE) - mgmt_claims = decode_jwt(mgmt_token) - if not mgmt_claims: - raise ValueError("Invalid management token format") + try: + mgmt_token = get_entra_token(credential, AZURE_MANAGEMENT_SCOPE) + except (CredentialUnavailableError, ClientAuthenticationError) as e: + raise ScopePermissionError("Failed to acquire token from management scope") from e + try: + mgmt_claims = decode_jwt(mgmt_token) + except TokenDecodeError: + raise xms_mirid = mgmt_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None @@ -131,7 +144,7 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: ) if not username: - raise ValueError( + raise UsernameExtractionError( "Could not determine username from token claims. " "Ensure the identity has the proper Azure AD attributes." ) @@ -141,22 +154,30 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> dict[str, str]: """Asynchronously obtains connection information from Entra authentication for Azure PostgreSQL. + This function acquires an access token from Azure Entra ID and extracts the username + from the token claims. It tries multiple claim sources to determine the username. + Parameters: credential (AsyncTokenCredential or None): The async credential used for token acquisition. - If None, the default Azure credentials are used. + If None, AsyncDefaultAzureCredential() is used to automatically discover credentials. Returns: - dict[str, str]: A dictionary with 'user' and 'password' keys containing the username and token. + dict[str, str]: A dictionary with 'user' and 'password' keys, where: + - 'user': The extracted username from token claims + - 'password': The Entra ID access token for database authentication Raises: - ValueError: If the username cannot be extracted from the token payload. + TokenDecodeError: If the JWT token cannot be decoded or is malformed. + UsernameExtractionError: If the username cannot be extracted from token claims. + ScopePermissionError: The token could not be acquired from the management scope, possibly due to insufficient permissions. """ credential = credential or AsyncDefaultAzureCredential() db_token = await get_entra_token_async(credential, AZURE_DB_FOR_POSTGRES_SCOPE) - db_claims = decode_jwt(db_token) - if not db_claims: - raise ValueError("Invalid DB token format") + try: + db_claims = decode_jwt(db_token) + except TokenDecodeError: + raise xms_mirid = db_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None @@ -166,10 +187,14 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d ) if not username: - mgmt_token = await get_entra_token_async(credential, AZURE_MANAGEMENT_SCOPE) - mgmt_claims = decode_jwt(mgmt_token) - if not mgmt_claims: - raise ValueError("Invalid management token format") + try: + mgmt_token = await get_entra_token_async(credential, AZURE_MANAGEMENT_SCOPE) + except (CredentialUnavailableError, ClientAuthenticationError) as e: + raise ScopePermissionError("Failed to acquire token from management scope") from e + try: + mgmt_claims = decode_jwt(mgmt_token) + except TokenDecodeError: + raise xms_mirid = mgmt_claims.get("xms_mirid") username = ( parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None @@ -179,6 +204,9 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d ) if not username: - raise ValueError("Could not determine username from token claims.") + raise UsernameExtractionError( + "Could not determine username from token claims. " + "Ensure the identity has the proper Azure AD attributes." + ) return {"user": username, "password": db_token} \ No newline at end of file diff --git a/python/src/azurepg_entra/errors.py b/python/src/azurepg_entra/errors.py new file mode 100644 index 0000000..ae87bdf --- /dev/null +++ b/python/src/azurepg_entra/errors.py @@ -0,0 +1,23 @@ +class EntraIdBaseError(Exception): + """Base class for all custom exceptions in the project.""" + pass + +class TokenDecodeError(EntraIdBaseError): + """Raised when a token value is invalid.""" + pass + +class UsernameExtractionError(EntraIdBaseError): + """Raised when username cannot be extracted from token.""" + pass + +class CredentialValueError(EntraIdBaseError): + """Raised when token credential is invalid.""" + pass + +class EntraConnectionValueError(EntraIdBaseError): + """Raised when Entra connection credentials are invalid.""" + pass + +class ScopePermissionError(EntraIdBaseError): + """Raised when the provided scope does not have sufficient permissions.""" + pass \ No newline at end of file diff --git a/python/src/azurepg_entra/psycopg2/__init__.py b/python/src/azurepg_entra/psycopg2/__init__.py index e30a0b9..ad20de5 100644 --- a/python/src/azurepg_entra/psycopg2/__init__.py +++ b/python/src/azurepg_entra/psycopg2/__init__.py @@ -12,28 +12,28 @@ - psycopg2-binary>=2.8.0 Classes: - SyncEntraConnection: Synchronous connection class with Entra ID authentication (psycopg2) + EntraConnection: Synchronous connection class with Entra ID authentication (psycopg2) Example usage: # Synchronous connection - from azurepg_entra.psycopg2 import SyncEntraConnection + from azurepg_entra.psycopg2 import EntraConnection connection_pool = pool.ThreadedConnectionPool( minconn=1, maxconn=5, host=SERVER, database=DATABASE, - connection_factory=SyncEntraConnection + connection_factory=EntraConnection ) """ try: - from .psycopg2_entra_id_extension import ( - SyncEntraConnection, + from .entra_connection import ( + EntraConnection, ) __all__ = [ - "SyncEntraConnection", + "EntraConnection", ] except ImportError as e: diff --git a/python/src/azurepg_entra/psycopg2/entra_connection.py b/python/src/azurepg_entra/psycopg2/entra_connection.py new file mode 100644 index 0000000..7a31606 --- /dev/null +++ b/python/src/azurepg_entra/psycopg2/entra_connection.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft. All rights reserved. +from psycopg2.extensions import connection, parse_dsn, make_dsn +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, CredentialValueError, ScopePermissionError +from azure.core.credentials import TokenCredential + +class EntraConnection(connection): + """Establishes a synchronous PostgreSQL connection using Entra authentication. + + This connection class automatically acquires Azure Entra ID credentials when user + or password are not provided in the DSN or connection parameters. Authentication + errors are printed to console for debugging purposes. + + Parameters: + dsn (str): PostgreSQL connection string (Data Source Name). + **kwargs: Additional keyword arguments including: + - credential (TokenCredential, optional): Azure credential for token acquisition. + If None, DefaultAzureCredential() is used. + - user (str, optional): Database username. If not provided, extracted from Entra token. + - password (str, optional): Database password. If not provided, uses Entra access token. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + def __init__(self, dsn, **kwargs): + # Extract current DSN params + dsn_params = parse_dsn(dsn) if dsn else {} + + credential = kwargs.pop("credential", None) + if credential and not isinstance(credential, (TokenCredential)): + raise CredentialValueError("credential must be a TokenCredential for sync connections") + + # Check if user and password are already provided + has_user = 'user' in dsn_params or 'user' in kwargs + has_password = 'password' in dsn_params or 'password' in kwargs + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + print(repr(e)) + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + + # Only update missing credentials + if not has_user and 'user' in entra_creds: + dsn_params['user'] = entra_creds['user'] + if not has_password and 'password' in entra_creds: + dsn_params['password'] = entra_creds['password'] + + # Update DSN params with any kwargs (kwargs take precedence) + dsn_params.update(kwargs) + + # Create new DSN with updated credentials + new_dsn = make_dsn(**dsn_params) + + # Call parent constructor with updated DSN only + super().__init__(new_dsn) \ No newline at end of file diff --git a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py deleted file mode 100644 index dab59df..0000000 --- a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -from psycopg2.extensions import connection, parse_dsn, make_dsn -from azurepg_entra.core import get_entra_conninfo - -# Define a custom connection class -class SyncEntraConnection(connection): - """Establishes a synchronous PostgreSQL connection using Entra authentication. - - The method checks for provided credentials. If the 'user' or 'password' are not set - in the DSN or keyword arguments, it acquires them from Entra via the provided or default credential. - - Parameters: - dsn: PostgreSQL connection string. - **kwargs: Keyword arguments including optional 'credential', and optionally 'user' and 'password'. - - Raises: - ValueError: If the provided credential is not a valid TokenCredential. - """ - def __init__(self, dsn, **kwargs): - # Extract current DSN params - dsn_params = parse_dsn(dsn) if dsn else {} - - # Check if user and password are already provided - has_user = 'user' in dsn_params or 'user' in kwargs - has_password = 'password' in dsn_params or 'password' in kwargs - - # Only get Entra credentials if user or password is missing - if not has_user or not has_password: - entra_creds = get_entra_conninfo(None) - - # Only update missing credentials - if not has_user and 'user' in entra_creds: - dsn_params['user'] = entra_creds['user'] - if not has_password and 'password' in entra_creds: - dsn_params['password'] = entra_creds['password'] - - # Update DSN params with any kwargs (kwargs take precedence) - dsn_params.update(kwargs) - - # Create new DSN with updated credentials - new_dsn = make_dsn(**dsn_params) - - # Call parent constructor with updated DSN only - super().__init__(new_dsn) \ No newline at end of file diff --git a/python/src/azurepg_entra/psycopg3/__init__.py b/python/src/azurepg_entra/psycopg3/__init__.py index 81a7a39..831b089 100644 --- a/python/src/azurepg_entra/psycopg3/__init__.py +++ b/python/src/azurepg_entra/psycopg3/__init__.py @@ -12,17 +12,17 @@ - psycopg[binary]>=3.1.0 Classes: - SyncEntraConnection: Synchronous connection class with Entra ID authentication + EntraConnection: Synchronous connection class with Entra ID authentication AsyncEntraConnection: Asynchronous connection class with Entra ID authentication Example usage: - from azurepg_entra.psycopg3 import SyncEntraConnection, AsyncEntraConnection + from azurepg_entra.psycopg3 import EntraConnection, AsyncEntraConnection from psycopg_pool import ConnectionPool, AsyncConnectionPool # Synchronous usage pool = ConnectionPool( conninfo="postgresql://myserver:5432/mydb", - connection_class=SyncEntraConnection + connection_class=EntraConnection ) # Asynchronous usage @@ -33,12 +33,10 @@ """ try: - from .psycopg3_entra_id_extension import ( - SyncEntraConnection, - AsyncEntraConnection - ) + from .entra_connection import EntraConnection + from .async_entra_connection import AsyncEntraConnection __all__ = [ - "SyncEntraConnection", + "EntraConnection", "AsyncEntraConnection" ] except ImportError as e: diff --git a/python/src/azurepg_entra/psycopg3/async_entra_connection.py b/python/src/azurepg_entra/psycopg3/async_entra_connection.py new file mode 100644 index 0000000..171e315 --- /dev/null +++ b/python/src/azurepg_entra/psycopg3/async_entra_connection.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft. All rights reserved. + +from psycopg import AsyncConnection +from azure.core.credentials_async import AsyncTokenCredential +from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, CredentialValueError, ScopePermissionError +from azurepg_entra.core import get_entra_conninfo_async + +from typing import Any +try: + from typing import Self +except ImportError: + from typing_extensions import Self # fallback for older Python + +class AsyncEntraConnection(AsyncConnection[tuple[Any, ...]]): + """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" + + @classmethod + async def connect(cls, *args: Any, **kwargs: Any) -> Self: + """Establishes an asynchronous PostgreSQL connection using Entra authentication. + + This method automatically acquires Azure Entra ID credentials when user or password + are not provided in the connection parameters. Authentication errors are printed to + console for debugging purposes. + + Parameters: + *args: Positional arguments to be forwarded to the parent connection method. + **kwargs: Keyword arguments including: + - credential (AsyncTokenCredential, optional): Async Azure credential for token acquisition. + - user (str, optional): Database username. If not provided, extracted from Entra token. + - password (str, optional): Database password. If not provided, uses Entra access token. + + Returns: + AsyncEntraConnection: An open asynchronous connection to the PostgreSQL database. + + Raises: + CredentialValueError: If the provided credential is not a valid AsyncTokenCredential. + EntraConnectionValueError: If Entra connection credentials are invalid. + """ + credential = kwargs.pop("credential", None) + if credential and not isinstance(credential, (AsyncTokenCredential)): + raise CredentialValueError("credential must be an AsyncTokenCredential for async connections") + + # Check if we need to acquire Entra authentication info + if not kwargs.get("user") or not kwargs.get("password"): + try: + entra_conninfo = await get_entra_conninfo_async(credential) + except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + print(repr(e)) + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Always use the token password when Entra authentication is needed + kwargs["password"] = entra_conninfo["password"] + if not kwargs.get("user"): + # If user isn't already set, use the username from the token + kwargs["user"] = entra_conninfo["user"] + return await super().connect(*args, **kwargs) \ No newline at end of file diff --git a/python/src/azurepg_entra/psycopg3/entra_connection.py b/python/src/azurepg_entra/psycopg3/entra_connection.py new file mode 100644 index 0000000..53efb8e --- /dev/null +++ b/python/src/azurepg_entra/psycopg3/entra_connection.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any +try: + from typing import Self +except ImportError: + from typing_extensions import Self # fallback for older Python +from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, CredentialValueError, ScopePermissionError +from azure.core.credentials import TokenCredential +from azurepg_entra.core import get_entra_conninfo +from psycopg import Connection + +class EntraConnection(Connection[tuple[Any, ...]]): + """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" + + @classmethod + def connect(cls, *args: Any, **kwargs: Any) -> Self: + """Establishes a synchronous PostgreSQL connection using Entra authentication. + + This method automatically acquires Azure Entra ID credentials when user or password + are not provided in the connection parameters. If authentication fails, the original + exception is re-raised to the caller. + + Parameters: + *args: Positional arguments to be forwarded to the parent connection method. + **kwargs: Keyword arguments including: + - credential (TokenCredential, optional): Azure credential for token acquisition. + - user (str, optional): Database username. If not provided, extracted from Entra token. + - password (str, optional): Database password. If not provided, uses Entra access token. + + Returns: + EntraConnection: An open synchronous connection to the PostgreSQL database. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + credential = kwargs.pop("credential", None) + if credential and not isinstance(credential, (TokenCredential)): + raise CredentialValueError("credential must be a TokenCredential for sync connections") + + # Check if we need to acquire Entra authentication info + if not kwargs.get("user") or not kwargs.get("password"): + try: + entra_conninfo = get_entra_conninfo(credential) + except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + print(repr(e)) + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Always use the token password when Entra authentication is needed + kwargs["password"] = entra_conninfo["password"] + if not kwargs.get("user"): + # If user isn't already set, use the username from the token + kwargs["user"] = entra_conninfo["user"] + return super().connect(*args, **kwargs) \ No newline at end of file diff --git a/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py b/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py deleted file mode 100644 index 2a7ade4..0000000 --- a/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -from typing import Any -try: - from typing import Self -except ImportError: - from typing_extensions import Self # fallback for older Python - -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential -from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async -from psycopg import AsyncConnection, Connection - -class SyncEntraConnection(Connection[tuple[Any, ...]]): - """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" - - @classmethod - def connect(cls, *args: Any, **kwargs: Any) -> Self: - """Establishes a synchronous PostgreSQL connection using Entra authentication. - - The method checks for provided credentials. If the 'user' or 'password' are not set - in the keyword arguments, it acquires them from Entra via the provided or default credential. - - Parameters: - *args: Positional arguments to be forwarded to the parent connection method. - **kwargs: Keyword arguments including optional 'credential', and optionally 'user' and 'password'. - - Returns: - SyncEntraConnection: An open synchronous connection to the PostgreSQL database. - - Raises: - ValueError: If the provided credential is not a valid TokenCredential. - """ - credential = kwargs.pop("credential", None) - if credential and not isinstance(credential, (TokenCredential)): - raise ValueError("credential must be a TokenCredential for sync connections") - - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = get_entra_conninfo(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] - return super().connect(*args, **kwargs) - - -class AsyncEntraConnection(AsyncConnection[tuple[Any, ...]]): - """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" - - @classmethod - async def connect(cls, *args: Any, **kwargs: Any) -> Self: - """Establishes an asynchronous PostgreSQL connection using Entra authentication. - - The method checks for provided credentials. If the 'user' or 'password' are not set - in the keyword arguments, it acquires them from Entra via the provided or default credential. - - Parameters: - *args: Positional arguments to be forwarded to the parent connection method. - **kwargs: Keyword arguments including optional 'credential', and optionally 'user' and 'password'. - - Returns: - AsyncEntraConnection: An open asynchronous connection to the PostgreSQL database. - - Raises: - ValueError: If the provided credential is not a valid AsyncTokenCredential. - """ - credential = kwargs.pop("credential", None) - if credential and not isinstance(credential, (AsyncTokenCredential)): - raise ValueError("credential must be an AsyncTokenCredential for async connections") - - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = await get_entra_conninfo_async(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] - return await super().connect(*args, **kwargs) \ No newline at end of file diff --git a/python/src/azurepg_entra/sqlalchemy/__init__.py b/python/src/azurepg_entra/sqlalchemy/__init__.py index c7922b5..18dafd8 100644 --- a/python/src/azurepg_entra/sqlalchemy/__init__.py +++ b/python/src/azurepg_entra/sqlalchemy/__init__.py @@ -2,7 +2,7 @@ """ SQLAlchemy integration for Azure PostgreSQL with Entra ID authentication. -This module provides seamless integration between SQLAlchemy and Azure Entra ID +This module provides integration between SQLAlchemy and Azure Entra ID authentication for PostgreSQL connections. It automatically handles token acquisition and credential injection through SQLAlchemy's event system. @@ -26,10 +26,8 @@ enable_entra_authentication_async: Enable Entra ID auth for asynchronous SQLAlchemy engines """ -from .sqlalchemy_entra_id_extension import ( - enable_entra_authentication, - enable_entra_authentication_async, -) +from .entra_connection import enable_entra_authentication +from .async_entra_connection import enable_entra_authentication_async __all__ = [ "enable_entra_authentication", diff --git a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py new file mode 100644 index 0000000..88b5388 --- /dev/null +++ b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py @@ -0,0 +1,44 @@ +from sqlalchemy.ext.asyncio import AsyncEngine +from azure.core.credentials_async import AsyncTokenCredential +from sqlalchemy import event +from azurepg_entra.errors import CredentialValueError, TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, ScopePermissionError +from azurepg_entra.core import get_entra_conninfo + +def enable_entra_authentication_async(engine: AsyncEngine): + """ + Enable Azure Entra ID authentication for an async SQLAlchemy engine. + + This function registers an event listener that automatically provides + Entra ID credentials for each database connection if they are not already set. + + Args: + engine: The async SQLAlchemy Engine to enable Entra authentication for + """ + + @event.listens_for(engine.sync_engine, "do_connect") + def provide_token_async(dialect, conn_rec, cargs, cparams): + """Event handler that provides Entra credentials for each async connection. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + credential = cparams.get("credential", None) + if credential and not isinstance(credential, (AsyncTokenCredential)): + raise CredentialValueError("credential must be an AsyncTokenCredential for async connections") + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + print(repr(e)) + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] \ No newline at end of file diff --git a/python/src/azurepg_entra/sqlalchemy/entra_connection.py b/python/src/azurepg_entra/sqlalchemy/entra_connection.py new file mode 100644 index 0000000..0391ed3 --- /dev/null +++ b/python/src/azurepg_entra/sqlalchemy/entra_connection.py @@ -0,0 +1,43 @@ +from sqlalchemy import Engine, event +from azure.core.credentials import TokenCredential +from azurepg_entra.errors import CredentialValueError, TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, ScopePermissionError +from azurepg_entra.core import get_entra_conninfo + +def enable_entra_authentication(engine: Engine): + """ + Enable Azure Entra ID authentication for a SQLAlchemy engine. + + This function registers an event listener that automatically provides + Entra ID credentials for each database connection if they are not already set. + + Args: + engine: The SQLAlchemy Engine to enable Entra authentication for + """ + + @event.listens_for(engine, "do_connect") + def provide_token(dialect, conn_rec, cargs, cparams): + """Event handler that provides Entra credentials for each connection. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + credential = cparams.get("credential", None) + if credential and not isinstance(credential, (TokenCredential)): + raise CredentialValueError("credential must be a TokenCredential for sync connections") + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + print(repr(e)) + raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] \ No newline at end of file diff --git a/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py b/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py deleted file mode 100644 index 81a3a21..0000000 --- a/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -import asyncio -import logging -import sys -from sqlalchemy import Engine, event -from sqlalchemy.ext.asyncio import AsyncEngine -from typing import Optional -from azure.core.credentials import TokenCredential -from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async - -logger = logging.getLogger(__name__) - -def enable_entra_authentication(engine: Engine, credential: Optional[TokenCredential] = None): - """ - Enable Azure Entra ID authentication for a SQLAlchemy engine. - - This function registers an event listener that automatically provides - Entra ID credentials for each database connection if they are not already set. - - Args: - engine: The SQLAlchemy Engine to enable Entra authentication for - credential: Optional Azure credential. If None, uses DefaultAzureCredential - """ - - @event.listens_for(engine, "do_connect") - def provide_token(dialect, conn_rec, cargs, cparams): - """Event handler that provides Entra credentials for each connection.""" - try: - # Check if credentials are already present - has_user = "user" in cparams - has_password = "password" in cparams - - # Only get Entra credentials if user or password is missing - if not has_user or not has_password: - entra_creds = get_entra_conninfo(credential) - - # Only update missing credentials - if not has_user and "user" in entra_creds: - cparams["user"] = entra_creds["user"] - if not has_password and "password" in entra_creds: - cparams["password"] = entra_creds["password"] - - logger.debug(f"Provided Entra credentials for user: {entra_creds.get('user', 'unknown')}") - else: - logger.debug("User and password already present, skipping Entra authentication") - except Exception as e: - logger.error(f"Failed to get Entra credentials: {e}") - raise - - -def enable_entra_authentication_async(engine: AsyncEngine, credential: Optional[TokenCredential] = None): - """ - Enable Azure Entra ID authentication for an async SQLAlchemy engine. - - This function registers an event listener that automatically provides - Entra ID credentials for each database connection if they are not already set. - - Args: - engine: The async SQLAlchemy Engine to enable Entra authentication for - credential: Optional Azure credential. If None, uses DefaultAzureCredential - """ - - @event.listens_for(engine.sync_engine, "do_connect") - def provide_token_async(dialect, conn_rec, cargs, cparams): - """Event handler that provides Entra credentials for each async connection.""" - try: - # Check if credentials are already present - has_user = "user" in cparams - has_password = "password" in cparams - - # Only get Entra credentials if user or password is missing - if not has_user or not has_password: - # For async engines, we need to handle the async credential fetching - try: - # Try to get the current event loop - asyncio.get_running_loop() - # If we're in a running loop, we need to run the async function in a thread - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, get_entra_conninfo_async(credential)) - entra_creds = future.result() - except RuntimeError: - # No running event loop, we can use asyncio.run directly - # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - entra_creds = asyncio.run(get_entra_conninfo_async(credential)) - - logger.debug("Successfully obtained async Entra credentials") - - # Only update missing credentials - if not has_user and "user" in entra_creds: - cparams["user"] = entra_creds["user"] - if not has_password and "password" in entra_creds: - cparams["password"] = entra_creds["password"] - - logger.debug(f"Provided async Entra credentials for user: {entra_creds.get('user', 'unknown')}") - else: - logger.debug("User and password already present, skipping Entra authentication") - except Exception as e: - logger.error(f"Failed to get async Entra credentials: {e}") - raise \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py index c52a351..11fe1ba 100644 --- a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import jwt import pytest -from unittest.mock import Mock, patch +from unittest.mock import patch from psycopg2.extensions import parse_dsn, make_dsn def create_test_token(payload): @@ -9,9 +9,9 @@ def create_test_token(payload): return jwt.encode(payload, key="", algorithm="none") -class TestSyncEntraConnection: +class TestEntraConnection: def test_dsn_processing_adds_entra_credentials(self): - """Test that SyncEntraConnection logic correctly merges Entra credentials into DSN.""" + """Test that EntraConnection logic correctly merges Entra credentials into DSN.""" payload = {"upn": "user@example.com"} token = create_test_token(payload) diff --git a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py index 47f7777..5e436e9 100644 --- a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py @@ -3,10 +3,11 @@ from unittest.mock import AsyncMock, Mock, patch from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential +from azurepg_entra.errors import CredentialValueError from azurepg_entra.psycopg3 import ( AsyncEntraConnection, - SyncEntraConnection, + EntraConnection, ) class TestSyncConnection: @@ -18,7 +19,7 @@ def test_connect_with_existing_credentials(self): mock_connection = Mock() mock_connect.return_value = mock_connection - result = SyncEntraConnection.connect(**kwargs) + result = EntraConnection.connect(**kwargs) assert result == mock_connection call_args = mock_connect.call_args[1] @@ -30,13 +31,13 @@ def test_connect_with_entra_credential(self): mock_credential = Mock(spec=TokenCredential) kwargs = {"host": "localhost", "credential": mock_credential} - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo', + with patch('azurepg_entra.psycopg3.entra_connection.get_entra_conninfo', return_value={"user": "test@example.com", "password": "token123"}): with patch('psycopg.Connection.connect') as mock_connect: mock_connection = Mock() mock_connect.return_value = mock_connection - result = SyncEntraConnection.connect(**kwargs) + result = EntraConnection.connect(**kwargs) assert result == mock_connection call_args = mock_connect.call_args[1] @@ -44,9 +45,9 @@ def test_connect_with_entra_credential(self): assert call_args["password"] == "token123" def test_connect_invalid_credential_type_throws(self): - """Test that invalid credential type raises ValueError.""" - with pytest.raises(ValueError, match="credential must be a TokenCredential for sync connections"): - SyncEntraConnection.connect(host="localhost", credential="invalid") + """Test that invalid credential type raises CredentialValueError.""" + with pytest.raises(CredentialValueError, match="credential must be a TokenCredential for sync connections"): + EntraConnection.connect(host="localhost", credential="invalid") class TestAsyncConnection: @@ -72,8 +73,8 @@ async def test_connect_with_entra_credential(self): mock_credential = AsyncMock(spec=AsyncTokenCredential) kwargs = {"host": "localhost", "credential": mock_credential} - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo_async', - return_value={"user": "test@example.com", "password": "token123"}): + with patch('azurepg_entra.psycopg3.async_entra_connection.get_entra_conninfo_async', + new_callable=AsyncMock, return_value={"user": "test@example.com", "password": "token123"}): with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_connect: mock_connection = Mock() mock_connect.return_value = mock_connection @@ -87,8 +88,8 @@ async def test_connect_with_entra_credential(self): @pytest.mark.asyncio async def test_connect_invalid_credential_type_throws(self): - """Test that invalid credential type raises ValueError (async).""" - with pytest.raises(ValueError, match="credential must be an AsyncTokenCredential for async connections"): + """Test that invalid credential type raises CredentialValueError (async).""" + with pytest.raises(CredentialValueError, match="credential must be an AsyncTokenCredential for async connections"): await AsyncEntraConnection.connect(host="localhost", credential="invalid") diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py index a4f1bf5..75eb282 100644 --- a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py @@ -28,7 +28,7 @@ def decorator(func): return decorator with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo') as mock_get_creds: + with patch('azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo') as mock_get_creds: mock_get_creds.return_value = {"user": "test@example.com", "password": "test_token"} from azurepg_entra.sqlalchemy import enable_entra_authentication @@ -57,7 +57,7 @@ def decorator(func): return decorator with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo') as mock_get_creds: + with patch('azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo') as mock_get_creds: from azurepg_entra.sqlalchemy import enable_entra_authentication enable_entra_authentication(mock_engine) @@ -99,8 +99,8 @@ def decorator(func): return decorator with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo_async') as mock_get_creds_async: - mock_get_creds_async.return_value = {"user": "test@example.com", "password": "test_token"} + with patch('azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo') as mock_get_creds: + mock_get_creds.return_value = {"user": "test@example.com", "password": "test_token"} from azurepg_entra.sqlalchemy import enable_entra_authentication_async enable_entra_authentication_async(mock_async_engine) @@ -109,7 +109,8 @@ def decorator(func): mock_cparams = {} captured_handler(None, None, None, mock_cparams) - # Verify credentials were added (asyncio.run is always called for async credential fetching) + # Verify credentials were added + mock_get_creds.assert_called_once_with(None) assert mock_cparams["user"] == "test@example.com" assert mock_cparams["password"] == "test_token" @@ -129,7 +130,7 @@ def decorator(func): return decorator with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo_async') as mock_get_creds_async: + with patch('azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo') as mock_get_creds: from azurepg_entra.sqlalchemy import enable_entra_authentication_async enable_entra_authentication_async(mock_async_engine) @@ -137,8 +138,8 @@ def decorator(func): mock_cparams = {"user": "existing@example.com", "password": "existing_password"} captured_handler(None, None, None, mock_cparams) - # Verify get_entra_conninfo_async was not called (credentials already exist) - mock_get_creds_async.assert_not_called() + # Verify get_entra_conninfo was not called (credentials already exist) + mock_get_creds.assert_not_called() assert mock_cparams["user"] == "existing@example.com" assert mock_cparams["password"] == "existing_password" diff --git a/python/tests/azure/data/postgresql/test_core_functionality.py b/python/tests/azure/data/postgresql/test_core_functionality.py index 38452bc..3b8a4cd 100644 --- a/python/tests/azure/data/postgresql/test_core_functionality.py +++ b/python/tests/azure/data/postgresql/test_core_functionality.py @@ -1,9 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. -import jwt +import json +import base64 import pytest from unittest.mock import AsyncMock, Mock, patch from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential +from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError from azurepg_entra.core import ( decode_jwt, @@ -13,8 +15,13 @@ ) def create_test_token(payload): - """Helper to create a test JWT token.""" - return jwt.encode(payload, key="", algorithm="none") + """Helper to create a test JWT token manually.""" + # Create a simple JWT-like token with header.payload.signature format + header = {"alg": "none", "typ": "JWT"} + header_encoded = base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip('=') + payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + signature = "" + return f"{header_encoded}.{payload_encoded}.{signature}" class TestJwtParsing: def test_decode_jwt_with_upn(self): @@ -29,9 +36,9 @@ def test_decode_jwt_with_preferred_username(self): result = decode_jwt(token) assert result == payload - def test_decode_jwt_invalid_format_returns_none(self): - result = decode_jwt("invalid.token") - assert result is None + def test_decode_jwt_invalid_format_raises_exception(self): + with pytest.raises(TokenDecodeError, match="Invalid JWT token format"): + decode_jwt("invalid.token") def test_parse_principal_name_valid_path(self): path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" @@ -59,8 +66,9 @@ def test_get_entra_conninfo_no_username_throws(self): payload = {"sub": "subject123"} token = create_test_token(payload) + # Mock both the DB token and the management token to have no username claims with patch('azurepg_entra.core.get_entra_token', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): + with pytest.raises(UsernameExtractionError, match="Could not determine username from token claims"): get_entra_conninfo(mock_credential) @pytest.mark.asyncio @@ -79,8 +87,9 @@ async def test_get_entra_conninfo_async_no_username_throws(self): payload = {"sub": "subject123"} token = create_test_token(payload) + # Mock both the DB token and the management token to have no username claims with patch('azurepg_entra.core.get_entra_token_async', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): + with pytest.raises(UsernameExtractionError, match="Could not determine username from token claims"): await get_entra_conninfo_async(mock_credential) From e3987eb1678d9562c636d98e4045dca4f89e8967 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Tue, 7 Oct 2025 23:10:09 -0700 Subject: [PATCH 06/19] Add ruff, mypy, and GitHub workflow for PR --- .github/workflows/pr.yml | 41 ++++++- python/pyproject.toml | 19 ++- .../create_db_connection_psycopg2.py | 18 +-- .../create_db_connection_psycopg3.py | 48 ++++---- .../create_db_connection_sqlalchemy.py | 63 ++++++---- python/src/azurepg_entra/__init__.py | 4 +- python/src/azurepg_entra/core.py | 70 ++++++++--- python/src/azurepg_entra/errors.py | 13 +- python/src/azurepg_entra/psycopg2/__init__.py | 10 +- .../psycopg2/entra_connection.py | 63 ++++++---- python/src/azurepg_entra/psycopg3/__init__.py | 20 ++- .../psycopg3/async_entra_connection.py | 41 +++++-- .../psycopg3/entra_connection.py | 39 ++++-- python/src/azurepg_entra/py.typed | 0 .../src/azurepg_entra/sqlalchemy/__init__.py | 10 +- .../sqlalchemy/async_entra_connection.py | 56 +++++++-- .../sqlalchemy/entra_connection.py | 47 ++++++-- .../test_psycopg2_entra_id_extension.py | 30 +++-- .../test_psycopg3_entra_id_extension.py | 80 +++++++----- .../test_sqlalchemy_entra_id_extension.py | 114 ++++++++++++------ .../postgresql/test_core_functionality.py | 48 +++++--- 21 files changed, 565 insertions(+), 269 deletions(-) create mode 100644 python/src/azurepg_entra/py.typed diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index f0e7e46..df37125 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -22,10 +22,43 @@ jobs: # Install dependencies - name: Install dependencies + working-directory: ./python run: | python -m pip install --upgrade pip - pip install .[all] + pip install -e .[all] - # Run mypy (type checker) - - name: Run mypy - run: mypy ./python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py \ No newline at end of file + # Run Ruff linter + - name: Run Ruff linting + working-directory: ./python + run: | + ruff check . + + # Run Ruff formatter check + - name: Check Ruff formatting + working-directory: ./python + run: | + ruff format --check . + + # Run mypy on source code (strict) + - name: Run mypy on source code + working-directory: ./python + run: | + mypy --strict src/ + + # Run mypy on samples (strict) + - name: Run mypy on samples + working-directory: ./python + run: | + mypy --strict samples/ + + # Run mypy on tests (basic) + - name: Run mypy on tests + working-directory: ./python + run: | + mypy tests/ + + # Run pytest + - name: Run tests + working-directory: ./python + run: | + pytest tests/ -v --tb=short \ No newline at end of file diff --git a/python/pyproject.toml b/python/pyproject.toml index 1d2be75..09d2296 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -45,7 +45,8 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "python-dotenv>=1.0.0", - "mypy ~= 1.15" + "mypy ~= 1.15", + "ruff>=0.8.0" ] # All optional dependencies combined @@ -57,7 +58,8 @@ all = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "python-dotenv>=1.0.0", - "mypy ~= 1.15" + "mypy ~= 1.15", + "ruff>=0.8.0" ] [tool.setuptools] @@ -71,6 +73,19 @@ include = ["azurepg_entra*"] Homepage = "https://github.com/v-anarendra_microsoft/entra-id-integration-for-drivers" Issues = "https://github.com/v-anarendra_microsoft/entra-id-integration-for-drivers/issues" +# Ruff configuration +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "UP", "B", "I", "N"] +ignore = ["N806"] # Allow mixed case variable names + +[tool.ruff.lint.per-file-ignores] +"tests/**/*" = ["N"] +"samples/**/*" = ["T201"] # Allow print statements in samples + # Installation examples: # Basic package only: pip install -e . # With psycopg3 support: pip install -e ".[psycopg3]" diff --git a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py index ba0c5ed..1b79300 100644 --- a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py +++ b/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py @@ -2,9 +2,11 @@ Sample demonstrating psycopg2 connection with synchronous Entra ID authentication for Azure PostgreSQL. """ -from dotenv import load_dotenv import os + +from dotenv import load_dotenv from psycopg2 import pool + from azurepg_entra.psycopg2 import EntraConnection # Load environment variables from .env file @@ -12,7 +14,8 @@ SERVER = os.getenv("POSTGRES_SERVER") DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") -def main_sync(): + +def main_sync() -> None: try: # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. # This class is applied whenever the connection pool creates a new connection, ensuring that Entra @@ -24,19 +27,19 @@ def main_sync(): maxconn=5, host=SERVER, database=DATABASE, - connection_factory=EntraConnection + connection_factory=EntraConnection, ) # Get a connection from the pool conn = connection_pool.getconn() - + try: with conn.cursor() as cur: # Query 1 cur.execute("SELECT now()") result = cur.fetchone() print(f"Database time: {result[0]}") - + # Query 2 cur.execute("SELECT current_user") user = cur.fetchone() @@ -45,10 +48,11 @@ def main_sync(): # Return connection to pool connection_pool.putconn(conn) connection_pool.closeall() - + except Exception as e: print(f"Sync - Error connecting to database: {e}") raise + if __name__ == "__main__": - main_sync() \ No newline at end of file + main_sync() diff --git a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py index c5c0b80..eebf337 100644 --- a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py +++ b/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py @@ -1,36 +1,39 @@ """ -Sample demonstrating both synchronous and asynchronous psycopg connections +Sample demonstrating both synchronous and asynchronous psycopg connections with Azure Entra ID authentication for Azure PostgreSQL. """ -from psycopg_pool import AsyncConnectionPool, ConnectionPool -from dotenv import load_dotenv import argparse import asyncio -import sys import os -from azurepg_entra.psycopg3 import EntraConnection, AsyncEntraConnection +import sys + +from dotenv import load_dotenv +from psycopg_pool import AsyncConnectionPool, ConnectionPool + +from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection # Load environment variables from .env file load_dotenv() SERVER = os.getenv("POSTGRES_SERVER") DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") -def main_sync(): + +def main_sync() -> None: """Synchronous connection example using psycopg with Entra ID authentication.""" try: # We use the SyncEntraConnection class to enable synchronous Entra-based authentication for database access. # This class is applied whenever the connection pool creates a new connection, ensuring that Entra # authentication tokens are properly managed and refreshed so that each connection uses a valid token. - # + # # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect pool = ConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, max_size=5, open=False, - connection_class=EntraConnection + connection_class=EntraConnection, ) pool.open() with pool, pool.connection() as conn, conn.cursor() as cur: @@ -38,16 +41,17 @@ def main_sync(): cur.execute("SELECT now()") result = cur.fetchone() print(f"Sync - Database time: {result}") - + # Query 2 cur.execute("SELECT current_user") user = cur.fetchone() - print(f"Sync - Connected as: {user[0]}") + print(f"Sync - Connected as: {user[0] if user else 'Unknown'}") except Exception as e: print(f"Sync - Error connecting to database: {e}") raise -async def main_async(): + +async def main_async() -> None: """Asynchronous connection example using psycopg with Entra ID authentication.""" try: @@ -61,7 +65,7 @@ async def main_async(): min_size=1, max_size=5, open=False, - connection_class=AsyncEntraConnection + connection_class=AsyncEntraConnection, ) await pool.open() async with pool, pool.connection() as conn, conn.cursor() as cur: @@ -73,14 +77,15 @@ async def main_async(): # Query 2 await cur.execute("SELECT current_user") user = await cur.fetchone() - print(f"Async - Connected as: {user[0]}") + print(f"Async - Connected as: {user[0] if user else 'Unknown'}") except Exception as e: print(f"Async - Error connecting to database: {e}") raise -async def main(mode: str = "async"): + +async def main(mode: str = "async") -> None: """Main function that runs sync and/or async examples based on mode. - + Args: mode: "sync", "async", or "both" to determine which examples to run """ @@ -91,7 +96,7 @@ async def main(mode: str = "async"): print("✅ Sync example completed successfully!") except Exception as e: print(f"❌ Sync example failed: {e}") - + if mode in ("async", "both"): if mode == "both": print("\n=== Running Asynchronous Example ===") @@ -103,6 +108,7 @@ async def main(mode: str = "async"): except Exception as e: print(f"❌ Async example failed: {e}") + if __name__ == "__main__": # Parse command line arguments parser = argparse.ArgumentParser( @@ -112,12 +118,12 @@ async def main(mode: str = "async"): "--mode", choices=["sync", "async", "both"], default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" + help="Run synchronous, asynchronous, or both examples (default: both)", ) args = parser.parse_args() - + # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main(args.mode)) \ No newline at end of file + + asyncio.run(main(args.mode)) diff --git a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py index b9fd360..f623bb3 100644 --- a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py +++ b/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py @@ -1,29 +1,35 @@ """ -Sample demonstrating both synchronous and asynchronous SQLAlchemy connections +Sample demonstrating both synchronous and asynchronous SQLAlchemy connections with Azure Entra ID authentication for Azure PostgreSQL. """ -from sqlalchemy import create_engine, text -from sqlalchemy.ext.asyncio import create_async_engine -from azurepg_entra.sqlalchemy import enable_entra_authentication, enable_entra_authentication_async -from dotenv import load_dotenv import argparse import asyncio -import sys import os +import sys + +from dotenv import load_dotenv +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +from azurepg_entra.sqlalchemy import ( + enable_entra_authentication, + enable_entra_authentication_async, +) # Load environment variables from .env file load_dotenv() SERVER = os.getenv("POSTGRES_SERVER") DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") -def main_sync(): + +def main_sync() -> None: """Synchronous connection example using SQLAlchemy with Entra ID authentication.""" try: # Create a synchronous engine engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") - + # We add an event listener to the engine to enable synchronous Entra authentication # for database access. This event listener is triggered whenever the connection pool # backing the engine creates a new connection, ensuring that Entra authentication tokens @@ -35,25 +41,28 @@ def main_sync(): with engine.connect() as conn: # Query 1 result = conn.execute(text("SELECT now()")) - print(f"Sync - Database time: {result.fetchone()[0]}") - + row = result.fetchone() + print(f"Sync - Database time: {row[0] if row else 'Unknown'}") + # Query 2 result = conn.execute(text("SELECT current_user")) - print(f"Sync - Connected as: {result.fetchone()[0]}") - + row = result.fetchone() + print(f"Sync - Connected as: {row[0] if row else 'Unknown'}") + # Clean up the engine engine.dispose() except Exception as e: print(f"Sync - Error connecting to database: {e}") raise -async def main_async(): + +async def main_async() -> None: """Asynchronous connection example using SQLAlchemy with Entra ID authentication.""" try: # Create an asynchronous engine engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") - + # We add an event listener to the engine to enable asynchronous Entra authentication # for database access. This event listener is triggered whenever the connection pool # backing the engine creates a new connection, ensuring that Entra authentication tokens @@ -65,21 +74,24 @@ async def main_async(): async with engine.connect() as conn: # Query 1 result = await conn.execute(text("SELECT now()")) - print(f"Async - Database time: {result.fetchone()[0]}") + row = result.fetchone() + print(f"Async - Database time: {row[0] if row else 'Unknown'}") # Query 2 result = await conn.execute(text("SELECT current_user")) - print(f"Async - Connected as: {result.fetchone()[0]}") - + row = result.fetchone() + print(f"Async - Connected as: {row[0] if row else 'Unknown'}") + # Clean up the engine await engine.dispose() except Exception as e: print(f"Async - Error connecting to database: {e}") raise -async def main(mode: str = "async"): + +async def main(mode: str = "async") -> None: """Main function that runs sync and/or async examples based on mode. - + Args: mode: "sync", "async", or "both" to determine which examples to run """ @@ -90,7 +102,7 @@ async def main(mode: str = "async"): print("✅ Sync example completed successfully!") except Exception as e: print(f"❌ Sync example failed: {e}") - + if mode in ("async", "both"): if mode == "both": print("\n=== Running Asynchronous SQLAlchemy Example ===") @@ -102,6 +114,7 @@ async def main(mode: str = "async"): except Exception as e: print(f"❌ Async example failed: {e}") + if __name__ == "__main__": # Parse command line arguments parser = argparse.ArgumentParser( @@ -111,12 +124,12 @@ async def main(mode: str = "async"): "--mode", choices=["sync", "async", "both"], default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" + help="Run synchronous, asynchronous, or both examples (default: both)", ) args = parser.parse_args() - + # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main(args.mode)) \ No newline at end of file + + asyncio.run(main(args.mode)) diff --git a/python/src/azurepg_entra/__init__.py b/python/src/azurepg_entra/__init__.py index 4852663..5d3a282 100644 --- a/python/src/azurepg_entra/__init__.py +++ b/python/src/azurepg_entra/__init__.py @@ -2,7 +2,7 @@ """ Azure PostgreSQL Entra ID Integration Library -This library provides connection classes for using Azure Entra ID authentication +This library provides connection classes for using Azure Entra ID authentication with Azure Database for PostgreSQL across different PostgreSQL drivers. Available modules (with optional dependencies): @@ -16,4 +16,4 @@ """ __version__ = "0.1.0" -__author__ = "Microsoft Corporation" \ No newline at end of file +__author__ = "Microsoft Corporation" diff --git a/python/src/azurepg_entra/core.py b/python/src/azurepg_entra/core.py index 8ebfcc8..ee1844c 100644 --- a/python/src/azurepg_entra/core.py +++ b/python/src/azurepg_entra/core.py @@ -1,22 +1,29 @@ import base64 import json from typing import Any, cast + from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import ClientAuthenticationError +from azure.identity import CredentialUnavailableError from azure.identity import DefaultAzureCredential as DefaultAzureCredential from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential -from azure.identity import CredentialUnavailableError -from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, ScopePermissionError + +from azurepg_entra.errors import ( + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" AZURE_MANAGEMENT_SCOPE = "https://management.azure.com/.default" + def get_entra_token(credential: TokenCredential | None, scope: str) -> str: """Acquires an Entra authentication token for Azure PostgreSQL synchronously. Parameters: - credential (TokenCredential or None): Credential object used to obtain the token. + credential (TokenCredential or None): Credential object used to obtain the token. If None, the default Azure credentials are used. scope (str): The scope for the token request. @@ -27,7 +34,10 @@ def get_entra_token(credential: TokenCredential | None, scope: str) -> str: cred = credential.get_token(scope) return cred.token -async def get_entra_token_async(credential: AsyncTokenCredential | None, scope: str) -> str: + +async def get_entra_token_async( + credential: AsyncTokenCredential | None, scope: str +) -> str: """Asynchronously acquires an Entra authentication token for Azure PostgreSQL. Parameters: @@ -42,7 +52,8 @@ async def get_entra_token_async(credential: AsyncTokenCredential | None, scope: async with credential: cred = await credential.get_token(scope) return cred.token - + + def decode_jwt(token: str) -> dict[str, Any]: """Decodes a JWT token to extract its payload claims. @@ -51,7 +62,7 @@ def decode_jwt(token: str) -> dict[str, Any]: Returns: dict[str, Any]: A dictionary containing the claims extracted from the token payload. - + Raises: TokenValueError: If the token format is invalid or cannot be decoded. """ @@ -63,6 +74,7 @@ def decode_jwt(token: str) -> dict[str, Any]: except Exception as e: raise TokenDecodeError("Invalid JWT token format") from e + def parse_principal_name(xms_mirid: str) -> str | None: """Parses the principal name from an Azure resource path. @@ -74,21 +86,24 @@ def parse_principal_name(xms_mirid: str) -> str | None: """ if not xms_mirid: return None - + # Parse the xms_mirid claim which looks like # /subscriptions/{subId}/resourcegroups/{resourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{principalName} - last_slash_index = xms_mirid.rfind('/') + last_slash_index = xms_mirid.rfind("/") if last_slash_index == -1: return None beginning = xms_mirid[:last_slash_index] - principal_name = xms_mirid[last_slash_index + 1:] + principal_name = xms_mirid[last_slash_index + 1 :] - if not principal_name or not beginning.lower().endswith("providers/microsoft.managedidentity/userassignedidentities"): + if not principal_name or not beginning.lower().endswith( + "providers/microsoft.managedidentity/userassignedidentities" + ): return None return principal_name + def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: """Synchronously obtains connection information from Entra authentication for Azure PostgreSQL. @@ -103,7 +118,7 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: dict[str, str]: A dictionary with 'user' and 'password' keys, where: - 'user': The extracted username from token claims - 'password': The Entra ID access token for database authentication - + Raises: TokenDecodeError: If the JWT token cannot be decoded or is malformed. UsernameExtractionError: If the username cannot be extracted from token claims. @@ -119,7 +134,9 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: raise xms_mirid = db_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or db_claims.get("upn") or db_claims.get("preferred_username") or db_claims.get("unique_name") @@ -130,14 +147,18 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: try: mgmt_token = get_entra_token(credential, AZURE_MANAGEMENT_SCOPE) except (CredentialUnavailableError, ClientAuthenticationError) as e: - raise ScopePermissionError("Failed to acquire token from management scope") from e + raise ScopePermissionError( + "Failed to acquire token from management scope" + ) from e try: mgmt_claims = decode_jwt(mgmt_token) except TokenDecodeError: raise xms_mirid = mgmt_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or mgmt_claims.get("upn") or mgmt_claims.get("preferred_username") or mgmt_claims.get("unique_name") @@ -151,7 +172,10 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: return {"user": username, "password": db_token} -async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> dict[str, str]: + +async def get_entra_conninfo_async( + credential: AsyncTokenCredential | None, +) -> dict[str, str]: """Asynchronously obtains connection information from Entra authentication for Azure PostgreSQL. This function acquires an access token from Azure Entra ID and extracts the username @@ -165,7 +189,7 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d dict[str, str]: A dictionary with 'user' and 'password' keys, where: - 'user': The extracted username from token claims - 'password': The Entra ID access token for database authentication - + Raises: TokenDecodeError: If the JWT token cannot be decoded or is malformed. UsernameExtractionError: If the username cannot be extracted from token claims. @@ -180,7 +204,9 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d raise xms_mirid = db_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or db_claims.get("upn") or db_claims.get("preferred_username") or db_claims.get("unique_name") @@ -190,14 +216,18 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d try: mgmt_token = await get_entra_token_async(credential, AZURE_MANAGEMENT_SCOPE) except (CredentialUnavailableError, ClientAuthenticationError) as e: - raise ScopePermissionError("Failed to acquire token from management scope") from e + raise ScopePermissionError( + "Failed to acquire token from management scope" + ) from e try: mgmt_claims = decode_jwt(mgmt_token) except TokenDecodeError: raise xms_mirid = mgmt_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or mgmt_claims.get("upn") or mgmt_claims.get("preferred_username") or mgmt_claims.get("unique_name") @@ -209,4 +239,4 @@ async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> d "Ensure the identity has the proper Azure AD attributes." ) - return {"user": username, "password": db_token} \ No newline at end of file + return {"user": username, "password": db_token} diff --git a/python/src/azurepg_entra/errors.py b/python/src/azurepg_entra/errors.py index ae87bdf..6249a29 100644 --- a/python/src/azurepg_entra/errors.py +++ b/python/src/azurepg_entra/errors.py @@ -1,23 +1,34 @@ class EntraIdBaseError(Exception): """Base class for all custom exceptions in the project.""" + pass + class TokenDecodeError(EntraIdBaseError): """Raised when a token value is invalid.""" + pass + class UsernameExtractionError(EntraIdBaseError): """Raised when username cannot be extracted from token.""" + pass + class CredentialValueError(EntraIdBaseError): """Raised when token credential is invalid.""" + pass + class EntraConnectionValueError(EntraIdBaseError): """Raised when Entra connection credentials are invalid.""" + pass + class ScopePermissionError(EntraIdBaseError): """Raised when the provided scope does not have sufficient permissions.""" - pass \ No newline at end of file + + pass diff --git a/python/src/azurepg_entra/psycopg2/__init__.py b/python/src/azurepg_entra/psycopg2/__init__.py index ad20de5..ec25784 100644 --- a/python/src/azurepg_entra/psycopg2/__init__.py +++ b/python/src/azurepg_entra/psycopg2/__init__.py @@ -7,7 +7,7 @@ Requirements: Install with: pip install azurepg-entra[psycopg2] - + This will install: - psycopg2-binary>=2.8.0 @@ -17,7 +17,7 @@ Example usage: # Synchronous connection from azurepg_entra.psycopg2 import EntraConnection - + connection_pool = pool.ThreadedConnectionPool( minconn=1, maxconn=5, @@ -31,14 +31,14 @@ from .entra_connection import ( EntraConnection, ) - + __all__ = [ "EntraConnection", ] - + except ImportError as e: # Provide a helpful error message if psycopg2 dependencies are missing raise ImportError( "psycopg2 dependencies are not installed. " "Install them with: pip install azurepg-entra[psycopg2]" - ) from e \ No newline at end of file + ) from e diff --git a/python/src/azurepg_entra/psycopg2/entra_connection.py b/python/src/azurepg_entra/psycopg2/entra_connection.py index 7a31606..82722f7 100644 --- a/python/src/azurepg_entra/psycopg2/entra_connection.py +++ b/python/src/azurepg_entra/psycopg2/entra_connection.py @@ -1,14 +1,24 @@ # Copyright (c) Microsoft. All rights reserved. -from psycopg2.extensions import connection, parse_dsn, make_dsn -from azurepg_entra.core import get_entra_conninfo -from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, CredentialValueError, ScopePermissionError +from typing import Any + from azure.core.credentials import TokenCredential +from psycopg2.extensions import connection, make_dsn, parse_dsn + +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) + class EntraConnection(connection): """Establishes a synchronous PostgreSQL connection using Entra authentication. - This connection class automatically acquires Azure Entra ID credentials when user - or password are not provided in the DSN or connection parameters. Authentication + This connection class automatically acquires Azure Entra ID credentials when user + or password are not provided in the DSN or connection parameters. Authentication errors are printed to console for debugging purposes. Parameters: @@ -23,37 +33,46 @@ class EntraConnection(connection): CredentialValueError: If the provided credential is not a valid TokenCredential. EntraConnectionValueError: If Entra connection credentials cannot be retrieved """ - def __init__(self, dsn, **kwargs): + + def __init__(self, dsn: str, **kwargs: Any) -> None: # Extract current DSN params dsn_params = parse_dsn(dsn) if dsn else {} credential = kwargs.pop("credential", None) if credential and not isinstance(credential, (TokenCredential)): - raise CredentialValueError("credential must be a TokenCredential for sync connections") - + raise CredentialValueError( + "credential must be a TokenCredential for sync connections" + ) + # Check if user and password are already provided - has_user = 'user' in dsn_params or 'user' in kwargs - has_password = 'password' in dsn_params or 'password' in kwargs - + has_user = "user" in dsn_params or "user" in kwargs + has_password = "password" in dsn_params or "password" in kwargs + # Only get Entra credentials if user or password is missing if not has_user or not has_password: try: entra_creds = get_entra_conninfo(credential) - except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + except ( + TokenDecodeError, + UsernameExtractionError, + ScopePermissionError, + ) as e: print(repr(e)) - raise EntraConnectionValueError("Could not retrieve Entra credentials") from e - + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e + # Only update missing credentials - if not has_user and 'user' in entra_creds: - dsn_params['user'] = entra_creds['user'] - if not has_password and 'password' in entra_creds: - dsn_params['password'] = entra_creds['password'] - + if not has_user and "user" in entra_creds: + dsn_params["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + dsn_params["password"] = entra_creds["password"] + # Update DSN params with any kwargs (kwargs take precedence) dsn_params.update(kwargs) - + # Create new DSN with updated credentials new_dsn = make_dsn(**dsn_params) - + # Call parent constructor with updated DSN only - super().__init__(new_dsn) \ No newline at end of file + super().__init__(new_dsn) diff --git a/python/src/azurepg_entra/psycopg3/__init__.py b/python/src/azurepg_entra/psycopg3/__init__.py index 831b089..879889c 100644 --- a/python/src/azurepg_entra/psycopg3/__init__.py +++ b/python/src/azurepg_entra/psycopg3/__init__.py @@ -7,7 +7,7 @@ Requirements: Install with: pip install azurepg-entra[psycopg3] - + This will install: - psycopg[binary]>=3.1.0 @@ -18,30 +18,28 @@ Example usage: from azurepg_entra.psycopg3 import EntraConnection, AsyncEntraConnection from psycopg_pool import ConnectionPool, AsyncConnectionPool - + # Synchronous usage pool = ConnectionPool( conninfo="postgresql://myserver:5432/mydb", connection_class=EntraConnection ) - - # Asynchronous usage + + # Asynchronous usage async_pool = AsyncConnectionPool( - conninfo="postgresql://myserver:5432/mydb", + conninfo="postgresql://myserver:5432/mydb", connection_class=AsyncEntraConnection ) """ try: - from .entra_connection import EntraConnection from .async_entra_connection import AsyncEntraConnection - __all__ = [ - "EntraConnection", - "AsyncEntraConnection" - ] + from .entra_connection import EntraConnection + + __all__ = ["EntraConnection", "AsyncEntraConnection"] except ImportError as e: # Provide a helpful error message if psycopg dependencies are missing raise ImportError( "psycopg3 dependencies are not installed. " "Install them with: pip install azurepg-entra[psycopg3]" - ) from e \ No newline at end of file + ) from e diff --git a/python/src/azurepg_entra/psycopg3/async_entra_connection.py b/python/src/azurepg_entra/psycopg3/async_entra_connection.py index 171e315..2d1b5f2 100644 --- a/python/src/azurepg_entra/psycopg3/async_entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/async_entra_connection.py @@ -1,25 +1,34 @@ # Copyright (c) Microsoft. All rights reserved. -from psycopg import AsyncConnection +from typing import Any + from azure.core.credentials_async import AsyncTokenCredential -from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, CredentialValueError, ScopePermissionError +from psycopg import AsyncConnection + from azurepg_entra.core import get_entra_conninfo_async +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) -from typing import Any try: - from typing import Self + from typing import Self except ImportError: from typing_extensions import Self # fallback for older Python + class AsyncEntraConnection(AsyncConnection[tuple[Any, ...]]): """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" - + @classmethod async def connect(cls, *args: Any, **kwargs: Any) -> Self: """Establishes an asynchronous PostgreSQL connection using Entra authentication. - This method automatically acquires Azure Entra ID credentials when user or password - are not provided in the connection parameters. Authentication errors are printed to + This method automatically acquires Azure Entra ID credentials when user or password + are not provided in the connection parameters. Authentication errors are printed to console for debugging purposes. Parameters: @@ -38,18 +47,26 @@ async def connect(cls, *args: Any, **kwargs: Any) -> Self: """ credential = kwargs.pop("credential", None) if credential and not isinstance(credential, (AsyncTokenCredential)): - raise CredentialValueError("credential must be an AsyncTokenCredential for async connections") - + raise CredentialValueError( + "credential must be an AsyncTokenCredential for async connections" + ) + # Check if we need to acquire Entra authentication info if not kwargs.get("user") or not kwargs.get("password"): try: entra_conninfo = await get_entra_conninfo_async(credential) - except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + except ( + TokenDecodeError, + UsernameExtractionError, + ScopePermissionError, + ) as e: print(repr(e)) - raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e # Always use the token password when Entra authentication is needed kwargs["password"] = entra_conninfo["password"] if not kwargs.get("user"): # If user isn't already set, use the username from the token kwargs["user"] = entra_conninfo["user"] - return await super().connect(*args, **kwargs) \ No newline at end of file + return await super().connect(*args, **kwargs) diff --git a/python/src/azurepg_entra/psycopg3/entra_connection.py b/python/src/azurepg_entra/psycopg3/entra_connection.py index 53efb8e..3246938 100644 --- a/python/src/azurepg_entra/psycopg3/entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/entra_connection.py @@ -1,24 +1,33 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any + try: - from typing import Self + from typing import Self except ImportError: from typing_extensions import Self # fallback for older Python -from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, CredentialValueError, ScopePermissionError from azure.core.credentials import TokenCredential -from azurepg_entra.core import get_entra_conninfo from psycopg import Connection +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) + + class EntraConnection(Connection[tuple[Any, ...]]): """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" - + @classmethod def connect(cls, *args: Any, **kwargs: Any) -> Self: """Establishes a synchronous PostgreSQL connection using Entra authentication. - This method automatically acquires Azure Entra ID credentials when user or password - are not provided in the connection parameters. If authentication fails, the original + This method automatically acquires Azure Entra ID credentials when user or password + are not provided in the connection parameters. If authentication fails, the original exception is re-raised to the caller. Parameters: @@ -37,18 +46,26 @@ def connect(cls, *args: Any, **kwargs: Any) -> Self: """ credential = kwargs.pop("credential", None) if credential and not isinstance(credential, (TokenCredential)): - raise CredentialValueError("credential must be a TokenCredential for sync connections") - + raise CredentialValueError( + "credential must be a TokenCredential for sync connections" + ) + # Check if we need to acquire Entra authentication info if not kwargs.get("user") or not kwargs.get("password"): try: entra_conninfo = get_entra_conninfo(credential) - except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + except ( + TokenDecodeError, + UsernameExtractionError, + ScopePermissionError, + ) as e: print(repr(e)) - raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e # Always use the token password when Entra authentication is needed kwargs["password"] = entra_conninfo["password"] if not kwargs.get("user"): # If user isn't already set, use the username from the token kwargs["user"] = entra_conninfo["user"] - return super().connect(*args, **kwargs) \ No newline at end of file + return super().connect(*args, **kwargs) diff --git a/python/src/azurepg_entra/py.typed b/python/src/azurepg_entra/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/python/src/azurepg_entra/sqlalchemy/__init__.py b/python/src/azurepg_entra/sqlalchemy/__init__.py index 18dafd8..30aca3f 100644 --- a/python/src/azurepg_entra/sqlalchemy/__init__.py +++ b/python/src/azurepg_entra/sqlalchemy/__init__.py @@ -10,14 +10,14 @@ Synchronous engines: from sqlalchemy import create_engine from azurepg_entra.sqlalchemy import enable_entra_authentication - + engine = create_engine("postgresql://myserver.postgres.database.azure.com/mydb") enable_entra_authentication(engine) - + Asynchronous engines: from sqlalchemy.ext.asyncio import create_async_engine from azurepg_entra.sqlalchemy import enable_entra_authentication_async - + engine = create_async_engine("postgresql+asyncpg://myserver.postgres.database.azure.com/mydb") enable_entra_authentication_async(engine) @@ -26,10 +26,10 @@ enable_entra_authentication_async: Enable Entra ID auth for asynchronous SQLAlchemy engines """ -from .entra_connection import enable_entra_authentication from .async_entra_connection import enable_entra_authentication_async +from .entra_connection import enable_entra_authentication __all__ = [ "enable_entra_authentication", "enable_entra_authentication_async", -] \ No newline at end of file +] diff --git a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py index 88b5388..7b79a31 100644 --- a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py @@ -1,44 +1,74 @@ -from sqlalchemy.ext.asyncio import AsyncEngine +from typing import Any + +from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential from sqlalchemy import event -from azurepg_entra.errors import CredentialValueError, TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, ScopePermissionError +from sqlalchemy.engine import Dialect +from sqlalchemy.ext.asyncio import AsyncEngine + from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) -def enable_entra_authentication_async(engine: AsyncEngine): + +def enable_entra_authentication_async(engine: AsyncEngine) -> None: """ Enable Azure Entra ID authentication for an async SQLAlchemy engine. This function registers an event listener that automatically provides Entra ID credentials for each database connection if they are not already set. - + Args: engine: The async SQLAlchemy Engine to enable Entra authentication for """ @event.listens_for(engine.sync_engine, "do_connect") - def provide_token_async(dialect, conn_rec, cargs, cparams): + def provide_token_async( + dialect: Dialect, conn_rec: Any, cargs: Any, cparams: dict[str, Any] + ) -> None: """Event handler that provides Entra credentials for each async connection. - + Raises: CredentialValueError: If the provided credential is not a valid TokenCredential. EntraConnectionValueError: If Entra connection credentials cannot be retrieved """ credential = cparams.get("credential", None) - if credential and not isinstance(credential, (AsyncTokenCredential)): - raise CredentialValueError("credential must be an AsyncTokenCredential for async connections") + if credential and not isinstance( + credential, (AsyncTokenCredential, TokenCredential) + ): + raise CredentialValueError( + "credential must be an AsyncTokenCredential or TokenCredential for async connections" + ) # Check if credentials are already present has_user = "user" in cparams has_password = "password" in cparams - + # Only get Entra credentials if user or password is missing if not has_user or not has_password: try: - entra_creds = get_entra_conninfo(credential) - except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + # Cast to TokenCredential since SQLAlchemy events are synchronous + sync_credential: TokenCredential | None = ( + credential + if isinstance(credential, TokenCredential) or credential is None + else None + ) + entra_creds = get_entra_conninfo(sync_credential) + except ( + TokenDecodeError, + UsernameExtractionError, + ScopePermissionError, + ) as e: print(repr(e)) - raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e # Only update missing credentials if not has_user and "user" in entra_creds: cparams["user"] = entra_creds["user"] if not has_password and "password" in entra_creds: - cparams["password"] = entra_creds["password"] \ No newline at end of file + cparams["password"] = entra_creds["password"] diff --git a/python/src/azurepg_entra/sqlalchemy/entra_connection.py b/python/src/azurepg_entra/sqlalchemy/entra_connection.py index 0391ed3..365fd48 100644 --- a/python/src/azurepg_entra/sqlalchemy/entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/entra_connection.py @@ -1,43 +1,64 @@ -from sqlalchemy import Engine, event +from typing import Any + from azure.core.credentials import TokenCredential -from azurepg_entra.errors import CredentialValueError, TokenDecodeError, UsernameExtractionError, EntraConnectionValueError, ScopePermissionError +from sqlalchemy import Engine, event +from sqlalchemy.engine import Dialect + from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) + -def enable_entra_authentication(engine: Engine): +def enable_entra_authentication(engine: Engine) -> None: """ Enable Azure Entra ID authentication for a SQLAlchemy engine. - + This function registers an event listener that automatically provides Entra ID credentials for each database connection if they are not already set. - + Args: engine: The SQLAlchemy Engine to enable Entra authentication for """ - + @event.listens_for(engine, "do_connect") - def provide_token(dialect, conn_rec, cargs, cparams): + def provide_token( + dialect: Dialect, conn_rec: Any, cargs: Any, cparams: dict[str, Any] + ) -> None: """Event handler that provides Entra credentials for each connection. - + Raises: CredentialValueError: If the provided credential is not a valid TokenCredential. EntraConnectionValueError: If Entra connection credentials cannot be retrieved """ credential = cparams.get("credential", None) if credential and not isinstance(credential, (TokenCredential)): - raise CredentialValueError("credential must be a TokenCredential for sync connections") + raise CredentialValueError( + "credential must be a TokenCredential for sync connections" + ) # Check if credentials are already present has_user = "user" in cparams has_password = "password" in cparams - + # Only get Entra credentials if user or password is missing if not has_user or not has_password: try: entra_creds = get_entra_conninfo(credential) - except (TokenDecodeError, UsernameExtractionError, ScopePermissionError) as e: + except ( + TokenDecodeError, + UsernameExtractionError, + ScopePermissionError, + ) as e: print(repr(e)) - raise EntraConnectionValueError("Could not retrieve Entra credentials") from e + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e # Only update missing credentials if not has_user and "user" in entra_creds: cparams["user"] = entra_creds["user"] if not has_password and "password" in entra_creds: - cparams["password"] = entra_creds["password"] \ No newline at end of file + cparams["password"] = entra_creds["password"] diff --git a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py index 11fe1ba..f7ed013 100644 --- a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. +from unittest.mock import patch + import jwt import pytest -from unittest.mock import patch -from psycopg2.extensions import parse_dsn, make_dsn +from psycopg2.extensions import make_dsn, parse_dsn + def create_test_token(payload): """Helper to create a test JWT token.""" @@ -14,25 +16,28 @@ def test_dsn_processing_adds_entra_credentials(self): """Test that EntraConnection logic correctly merges Entra credentials into DSN.""" payload = {"upn": "user@example.com"} token = create_test_token(payload) - - with patch('azurepg_entra.core.get_entra_conninfo') as mock_get_creds: - mock_get_creds.return_value = {"user": "user@example.com", "password": token} - + + with patch("azurepg_entra.core.get_entra_conninfo") as mock_get_creds: + mock_get_creds.return_value = { + "user": "user@example.com", + "password": token, + } + from azurepg_entra.core import get_entra_conninfo - + # Test with existing DSN parameters original_dsn = "host=localhost port=5432 dbname=testdb sslmode=require" entra_creds = get_entra_conninfo(None) - + dsn_params = parse_dsn(original_dsn) if original_dsn else {} dsn_params.update(entra_creds) new_dsn = make_dsn(**dsn_params) - + mock_get_creds.assert_called_once_with(None) - + # Original params preserved assert "host=localhost" in new_dsn - assert "port=5432" in new_dsn + assert "port=5432" in new_dsn assert "dbname=testdb" in new_dsn assert "sslmode=require" in new_dsn # Entra creds added @@ -42,5 +47,6 @@ def test_dsn_processing_adds_entra_credentials(self): if __name__ == "__main__": import sys + exit_code = pytest.main([__file__, "-v", "--tb=short"]) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py index 5e436e9..f730c84 100644 --- a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py @@ -1,26 +1,32 @@ # Copyright (c) Microsoft. All rights reserved. -import pytest from unittest.mock import AsyncMock, Mock, patch + +import pytest from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential -from azurepg_entra.errors import CredentialValueError +from azurepg_entra.errors import CredentialValueError from azurepg_entra.psycopg3 import ( AsyncEntraConnection, EntraConnection, ) + class TestSyncConnection: def test_connect_with_existing_credentials(self): """Test that existing user/password credentials are used without fetching Entra credentials.""" - kwargs = {"host": "localhost", "user": "existing_user", "password": "existing_password"} - - with patch('psycopg.Connection.connect') as mock_connect: + kwargs = { + "host": "localhost", + "user": "existing_user", + "password": "existing_password", + } + + with patch("psycopg.Connection.connect") as mock_connect: mock_connection = Mock() mock_connect.return_value = mock_connection - + result = EntraConnection.connect(**kwargs) - + assert result == mock_connection call_args = mock_connect.call_args[1] assert call_args["user"] == "existing_user" @@ -30,15 +36,17 @@ def test_connect_with_entra_credential(self): """Test that Entra credentials are fetched and used when no user/password provided.""" mock_credential = Mock(spec=TokenCredential) kwargs = {"host": "localhost", "credential": mock_credential} - - with patch('azurepg_entra.psycopg3.entra_connection.get_entra_conninfo', - return_value={"user": "test@example.com", "password": "token123"}): - with patch('psycopg.Connection.connect') as mock_connect: + + with patch( + "azurepg_entra.psycopg3.entra_connection.get_entra_conninfo", + return_value={"user": "test@example.com", "password": "token123"}, + ): + with patch("psycopg.Connection.connect") as mock_connect: mock_connection = Mock() mock_connect.return_value = mock_connection - + result = EntraConnection.connect(**kwargs) - + assert result == mock_connection call_args = mock_connect.call_args[1] assert call_args["user"] == "test@example.com" @@ -46,7 +54,10 @@ def test_connect_with_entra_credential(self): def test_connect_invalid_credential_type_throws(self): """Test that invalid credential type raises CredentialValueError.""" - with pytest.raises(CredentialValueError, match="credential must be a TokenCredential for sync connections"): + with pytest.raises( + CredentialValueError, + match="credential must be a TokenCredential for sync connections", + ): EntraConnection.connect(host="localhost", credential="invalid") @@ -54,14 +65,20 @@ class TestAsyncConnection: @pytest.mark.asyncio async def test_connect_with_existing_credentials(self): """Test that existing user/password credentials are used without fetching Entra credentials (async).""" - kwargs = {"host": "localhost", "user": "existing_user", "password": "existing_password"} - - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_connect: + kwargs = { + "host": "localhost", + "user": "existing_user", + "password": "existing_password", + } + + with patch( + "psycopg.AsyncConnection.connect", new_callable=AsyncMock + ) as mock_connect: mock_connection = Mock() mock_connect.return_value = mock_connection - + result = await AsyncEntraConnection.connect(**kwargs) - + assert result == mock_connection call_args = mock_connect.call_args[1] assert call_args["user"] == "existing_user" @@ -72,15 +89,20 @@ async def test_connect_with_entra_credential(self): """Test that Entra credentials are fetched and used when no user/password provided (async).""" mock_credential = AsyncMock(spec=AsyncTokenCredential) kwargs = {"host": "localhost", "credential": mock_credential} - - with patch('azurepg_entra.psycopg3.async_entra_connection.get_entra_conninfo_async', - new_callable=AsyncMock, return_value={"user": "test@example.com", "password": "token123"}): - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_connect: + + with patch( + "azurepg_entra.psycopg3.async_entra_connection.get_entra_conninfo_async", + new_callable=AsyncMock, + return_value={"user": "test@example.com", "password": "token123"}, + ): + with patch( + "psycopg.AsyncConnection.connect", new_callable=AsyncMock + ) as mock_connect: mock_connection = Mock() mock_connect.return_value = mock_connection - + result = await AsyncEntraConnection.connect(**kwargs) - + assert result == mock_connection call_args = mock_connect.call_args[1] assert call_args["user"] == "test@example.com" @@ -89,11 +111,15 @@ async def test_connect_with_entra_credential(self): @pytest.mark.asyncio async def test_connect_invalid_credential_type_throws(self): """Test that invalid credential type raises CredentialValueError (async).""" - with pytest.raises(CredentialValueError, match="credential must be an AsyncTokenCredential for async connections"): + with pytest.raises( + CredentialValueError, + match="credential must be an AsyncTokenCredential for async connections", + ): await AsyncEntraConnection.connect(host="localhost", credential="invalid") if __name__ == "__main__": import sys + exit_code = pytest.main([__file__, "-v", "--tb=short"]) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py index 75eb282..7590f0b 100644 --- a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py @@ -1,43 +1,54 @@ # Copyright (c) Microsoft. All rights reserved. -import pytest from unittest.mock import Mock, patch +import pytest + + class TestEnableEntraAuthentication: def test_sync_authentication_function_registration(self): """Test that enable_entra_authentication registers event listener successfully.""" mock_engine = Mock() - - with patch('sqlalchemy.event.listens_for') as mock_event_listener: + + with patch("sqlalchemy.event.listens_for") as mock_event_listener: from azurepg_entra.sqlalchemy import enable_entra_authentication + enable_entra_authentication(mock_engine) - + # Verify event listener was registered with correct parameters mock_event_listener.assert_called_once_with(mock_engine, "do_connect") def test_provide_token_method(self): """Test the provide_token event handler method directly.""" mock_engine = Mock() - + # Capture the event handler function captured_handler = None + def capture_handler(engine, event_name): def decorator(func): nonlocal captured_handler captured_handler = func return func + return decorator - - with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo') as mock_get_creds: - mock_get_creds.return_value = {"user": "test@example.com", "password": "test_token"} - + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo" + ) as mock_get_creds: + mock_get_creds.return_value = { + "user": "test@example.com", + "password": "test_token", + } + from azurepg_entra.sqlalchemy import enable_entra_authentication + enable_entra_authentication(mock_engine) - + # Test the captured handler directly mock_cparams = {} captured_handler(None, None, None, mock_cparams) - + # Verify credentials were added mock_get_creds.assert_called_once_with(None) assert mock_cparams["user"] == "test@example.com" @@ -46,25 +57,33 @@ def decorator(func): def test_provide_token_skips_existing_credentials(self): """Test that provide_token skips when credentials already exist.""" mock_engine = Mock() - + # Capture the event handler function captured_handler = None + def capture_handler(engine, event_name): def decorator(func): nonlocal captured_handler captured_handler = func return func + return decorator - - with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo') as mock_get_creds: + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo" + ) as mock_get_creds: from azurepg_entra.sqlalchemy import enable_entra_authentication + enable_entra_authentication(mock_engine) - + # Test with existing credentials - mock_cparams = {"user": "existing@example.com", "password": "existing_password"} + mock_cparams = { + "user": "existing@example.com", + "password": "existing_password", + } captured_handler(None, None, None, mock_cparams) - + # Verify get_entra_conninfo was not called mock_get_creds.assert_not_called() assert mock_cparams["user"] == "existing@example.com" @@ -75,11 +94,12 @@ def test_async_authentication_function_registration(self): mock_async_engine = Mock() mock_sync_engine = Mock() mock_async_engine.sync_engine = mock_sync_engine - - with patch('sqlalchemy.event.listens_for') as mock_event_listener: + + with patch("sqlalchemy.event.listens_for") as mock_event_listener: from azurepg_entra.sqlalchemy import enable_entra_authentication_async + enable_entra_authentication_async(mock_async_engine) - + # Verify event listener was registered on sync_engine mock_event_listener.assert_called_once_with(mock_sync_engine, "do_connect") @@ -88,27 +108,35 @@ def test_provide_token_async_method(self): mock_async_engine = Mock() mock_sync_engine = Mock() mock_async_engine.sync_engine = mock_sync_engine - + # Capture the event handler function captured_handler = None + def capture_handler(engine, event_name): def decorator(func): nonlocal captured_handler captured_handler = func return func + return decorator - - with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo') as mock_get_creds: - mock_get_creds.return_value = {"user": "test@example.com", "password": "test_token"} - + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo" + ) as mock_get_creds: + mock_get_creds.return_value = { + "user": "test@example.com", + "password": "test_token", + } + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + enable_entra_authentication_async(mock_async_engine) - + # Test the captured handler directly mock_cparams = {} captured_handler(None, None, None, mock_cparams) - + # Verify credentials were added mock_get_creds.assert_called_once_with(None) assert mock_cparams["user"] == "test@example.com" @@ -119,33 +147,41 @@ def test_provide_token_async_skips_existing_credentials(self): mock_async_engine = Mock() mock_sync_engine = Mock() mock_async_engine.sync_engine = mock_sync_engine - + # Capture the event handler function captured_handler = None + def capture_handler(engine, event_name): def decorator(func): nonlocal captured_handler captured_handler = func return func + return decorator - - with patch('sqlalchemy.event.listens_for', side_effect=capture_handler): - with patch('azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo') as mock_get_creds: + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo" + ) as mock_get_creds: from azurepg_entra.sqlalchemy import enable_entra_authentication_async + enable_entra_authentication_async(mock_async_engine) - + # Test with existing credentials - mock_cparams = {"user": "existing@example.com", "password": "existing_password"} + mock_cparams = { + "user": "existing@example.com", + "password": "existing_password", + } captured_handler(None, None, None, mock_cparams) - + # Verify get_entra_conninfo was not called (credentials already exist) mock_get_creds.assert_not_called() assert mock_cparams["user"] == "existing@example.com" assert mock_cparams["password"] == "existing_password" - if __name__ == "__main__": import sys + exit_code = pytest.main([__file__, "-v", "--tb=short"]) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/python/tests/azure/data/postgresql/test_core_functionality.py b/python/tests/azure/data/postgresql/test_core_functionality.py index 3b8a4cd..bb2a554 100644 --- a/python/tests/azure/data/postgresql/test_core_functionality.py +++ b/python/tests/azure/data/postgresql/test_core_functionality.py @@ -1,28 +1,35 @@ # Copyright (c) Microsoft. All rights reserved. -import json import base64 -import pytest +import json from unittest.mock import AsyncMock, Mock, patch + +import pytest from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential -from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError from azurepg_entra.core import ( decode_jwt, - parse_principal_name, get_entra_conninfo, get_entra_conninfo_async, + parse_principal_name, ) +from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError + def create_test_token(payload): """Helper to create a test JWT token manually.""" # Create a simple JWT-like token with header.payload.signature format header = {"alg": "none", "typ": "JWT"} - header_encoded = base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip('=') - payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) signature = "" return f"{header_encoded}.{payload_encoded}.{signature}" + class TestJwtParsing: def test_decode_jwt_with_upn(self): payload = {"upn": "user@example.com"} @@ -56,8 +63,8 @@ def test_get_entra_conninfo_with_upn(self): mock_credential = Mock(spec=TokenCredential) payload = {"upn": "user@example.com"} token = create_test_token(payload) - - with patch('azurepg_entra.core.get_entra_token', return_value=token): + + with patch("azurepg_entra.core.get_entra_token", return_value=token): result = get_entra_conninfo(mock_credential) assert result == {"user": "user@example.com", "password": token} @@ -65,10 +72,13 @@ def test_get_entra_conninfo_no_username_throws(self): mock_credential = Mock(spec=TokenCredential) payload = {"sub": "subject123"} token = create_test_token(payload) - + # Mock both the DB token and the management token to have no username claims - with patch('azurepg_entra.core.get_entra_token', return_value=token): - with pytest.raises(UsernameExtractionError, match="Could not determine username from token claims"): + with patch("azurepg_entra.core.get_entra_token", return_value=token): + with pytest.raises( + UsernameExtractionError, + match="Could not determine username from token claims", + ): get_entra_conninfo(mock_credential) @pytest.mark.asyncio @@ -76,8 +86,8 @@ async def test_get_entra_conninfo_async_with_upn(self): mock_credential = AsyncMock(spec=AsyncTokenCredential) payload = {"upn": "user@example.com"} token = create_test_token(payload) - - with patch('azurepg_entra.core.get_entra_token_async', return_value=token): + + with patch("azurepg_entra.core.get_entra_token_async", return_value=token): result = await get_entra_conninfo_async(mock_credential) assert result == {"user": "user@example.com", "password": token} @@ -86,14 +96,18 @@ async def test_get_entra_conninfo_async_no_username_throws(self): mock_credential = AsyncMock(spec=AsyncTokenCredential) payload = {"sub": "subject123"} token = create_test_token(payload) - + # Mock both the DB token and the management token to have no username claims - with patch('azurepg_entra.core.get_entra_token_async', return_value=token): - with pytest.raises(UsernameExtractionError, match="Could not determine username from token claims"): + with patch("azurepg_entra.core.get_entra_token_async", return_value=token): + with pytest.raises( + UsernameExtractionError, + match="Could not determine username from token claims", + ): await get_entra_conninfo_async(mock_credential) if __name__ == "__main__": import sys + exit_code = pytest.main([__file__, "-v", "--tb=short"]) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) From f66f56d2b8cc55697db25f4eea07e0aa95b73889 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Tue, 7 Oct 2025 23:29:16 -0700 Subject: [PATCH 07/19] Add types-psycopg2 dependency for static type checking --- python/pyproject.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 09d2296..e1d4437 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -46,7 +46,8 @@ dev = [ "pytest-asyncio>=0.21.0", "python-dotenv>=1.0.0", "mypy ~= 1.15", - "ruff>=0.8.0" + "ruff>=0.8.0", + "types-psycopg2>=2.9.0" ] # All optional dependencies combined @@ -59,7 +60,8 @@ all = [ "pytest-asyncio>=0.21.0", "python-dotenv>=1.0.0", "mypy ~= 1.15", - "ruff>=0.8.0" + "ruff>=0.8.0", + "types-psycopg2>=2.9.0" ] [tool.setuptools] From 9079dec27c68750578c8d50a86c14eb33e72c092 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Tue, 7 Oct 2025 23:35:50 -0700 Subject: [PATCH 08/19] Add psycopg-pool to dev dependencies --- python/pyproject.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index e1d4437..ff34530 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -47,7 +47,8 @@ dev = [ "python-dotenv>=1.0.0", "mypy ~= 1.15", "ruff>=0.8.0", - "types-psycopg2>=2.9.0" + "types-psycopg2>=2.9.0", + "psycopg-pool>=3.1.0" ] # All optional dependencies combined @@ -61,7 +62,8 @@ all = [ "python-dotenv>=1.0.0", "mypy ~= 1.15", "ruff>=0.8.0", - "types-psycopg2>=2.9.0" + "types-psycopg2>=2.9.0", + "psycopg-pool>=3.1.0" ] [tool.setuptools] From 3967f0423cb09dfdd7b402704c786ce49bf710ef Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Wed, 8 Oct 2025 10:26:23 -0700 Subject: [PATCH 09/19] Add sample programs to test token refresh capabilities --- ...eate_everlasting_db_connection_psycopg2.py | 116 ++++++++++++ ...eate_everlasting_db_connection_psycopg3.py | 177 ++++++++++++++++++ ...te_everlasting_db_connection_sqlalchemy.py | 173 +++++++++++++++++ 3 files changed, 466 insertions(+) create mode 100644 python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py create mode 100644 python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py create mode 100644 python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py diff --git a/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py new file mode 100644 index 0000000..1819692 --- /dev/null +++ b/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py @@ -0,0 +1,116 @@ +""" +Sample demonstrating an everlasting psycopg2 connection with Azure Entra ID authentication +for Azure PostgreSQL that runs queries indefinitely to test token refresh capabilities. +""" + +import argparse +import os +import sys +import time +from datetime import datetime + +import psycopg2 +from azurepg_entra.psycopg2 import EntraConnection +from dotenv import load_dotenv +from psycopg2.pool import ThreadedConnectionPool + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def run_everlasting_queries(interval_minutes: int = 2) -> None: + """Run database queries indefinitely with psycopg2 and Entra authentication using ThreadedConnectionPool.""" + + print("=== Running Everlasting psycopg2 Connection Pool Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create connection string + conninfo = f"postgresql://{SERVER}:5432/{DATABASE}" + + # Create connection pool with EntraConnection factory + print("Creating ThreadedConnectionPool with EntraConnection factory...") + pool = ThreadedConnectionPool( + minconn=1, + maxconn=3, + dsn=conninfo, + connection_factory=EntraConnection + ) + + execution_count = 0 + + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Execution #{execution_count} at {current_time}") + + # Get connection from pool + conn = pool.getconn() + + try: + with conn.cursor() as cur: + # Query 1: Get PostgreSQL version + cur.execute("SELECT version()") + version = cur.fetchone() + print(f"Connected to PostgreSQL: {version[0][:50]}...") + + # Query 2: Get current user (shows the Entra username) + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Connected as: {user[0]}") + + # Query 3: Get current timestamp + cur.execute("SELECT now()") + timestamp = cur.fetchone() + print(f"Server time: {timestamp[0]}") + + print("Query execution successful!") + + except psycopg2.Error as e: + print(f"Database error: {e}") + except Exception as e: + print(f"Unexpected error: {e}") + finally: + # Return connection to pool + pool.putconn(conn) + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + # Close all connections in the pool + print("Closing connection pool...") + pool.closeall() + + +def main() -> None: + """Main function with command line argument parsing.""" + parser = argparse.ArgumentParser( + description="Demonstrate everlasting psycopg2 connection with Azure Entra ID authentication" + ) + parser.add_argument( + "--interval", + type=int, + default=2, + help="Query execution interval in minutes (default: 2)" + ) + args = parser.parse_args() + + # Validate environment variables + if not SERVER: + print("Error: POSTGRES_SERVER environment variable is required") + sys.exit(1) + + print(f"Target server: {SERVER}") + print(f"Target database: {DATABASE}") + print(f"Query interval: {args.interval} minutes\n") + + # Run the everlasting queries + run_everlasting_queries(args.interval) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py new file mode 100644 index 0000000..ab13057 --- /dev/null +++ b/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py @@ -0,0 +1,177 @@ +""" +Sample demonstrating everlasting synchronous and asynchronous psycopg3 connections +with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely +to test token refresh capabilities. +""" + +import argparse +import asyncio +import os +import sys +import time +from datetime import datetime + +from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection +from dotenv import load_dotenv +from psycopg_pool import AsyncConnectionPool, ConnectionPool + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: + """Run synchronous database queries indefinitely with psycopg3 and Entra authentication.""" + + print("=== Running Everlasting Synchronous psycopg3 Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create connection pool with Entra authentication + pool = ConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=3, + open=False, + connection_class=EntraConnection + ) + pool.open() + + execution_count = 0 + + try: + with pool: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Sync Execution #{execution_count} at {current_time}") + + try: + with pool.connection() as conn, conn.cursor() as cur: + # Query 1: Get PostgreSQL version + cur.execute("SELECT version()") + version = cur.fetchone() + print(f"Connected to PostgreSQL: {version[0][:50]}...") + + # Query 2: Get current user + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Connected as: {user[0] if user else 'Unknown'}") + + # Query 3: Get current timestamp + cur.execute("SELECT now()") + timestamp = cur.fetchone() + print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") + + print("Sync query execution successful!") + + except Exception as e: + print(f"Database error: {e}") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + pool.close() + + +async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: + """Run asynchronous database queries indefinitely with psycopg3 and Entra authentication.""" + + print("=== Running Everlasting Asynchronous psycopg3 Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create async connection pool with Entra authentication + pool = AsyncConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=3, + open=False, + connection_class=AsyncEntraConnection + ) + await pool.open() + + execution_count = 0 + + try: + async with pool: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Async Execution #{execution_count} at {current_time}") + + try: + async with pool.connection() as conn, conn.cursor() as cur: + # Query 1: Get PostgreSQL version + await cur.execute("SELECT version()") + version = await cur.fetchone() + print(f"Connected to PostgreSQL: {version[0][:50]}...") + + # Query 2: Get current user + await cur.execute("SELECT current_user") + user = await cur.fetchone() + print(f"Connected as: {user[0] if user else 'Unknown'}") + + # Query 3: Get current timestamp + await cur.execute("SELECT now()") + timestamp = await cur.fetchone() + print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") + + print("Async query execution successful!") + + except Exception as e: + print(f"Database error: {e}") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + await asyncio.sleep(interval_minutes * 60) + finally: + await pool.close() + + +async def main() -> None: + """Main function with command line argument parsing.""" + parser = argparse.ArgumentParser( + description="Demonstrate everlasting psycopg3 connections with Azure Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)" + ) + parser.add_argument( + "--interval", + type=int, + default=2, + help="Query execution interval in minutes (default: 2)" + ) + args = parser.parse_args() + + # Validate environment variables + if not SERVER: + print("Error: POSTGRES_SERVER environment variable is required") + sys.exit(1) + + print(f"Target server: {SERVER}") + print(f"Target database: {DATABASE}") + print(f"Query interval: {args.interval} minutes") + print(f"Mode: {args.mode}\n") + + if args.mode in ("sync", "both"): + run_everlasting_sync_queries(args.interval) + + if args.mode in ("async", "both"): + if args.mode == "both": + print("\n" + "="*60 + "\n") + await run_everlasting_async_queries(args.interval) + + +if __name__ == "__main__": + # Set Windows event loop policy for compatibility if needed + if sys.platform.startswith('win'): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main()) \ No newline at end of file diff --git a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py new file mode 100644 index 0000000..165dfba --- /dev/null +++ b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py @@ -0,0 +1,173 @@ +""" +Sample demonstrating everlasting synchronous and asynchronous SQLAlchemy connections +with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely +to test token refresh capabilities. +""" + +import argparse +import asyncio +import os +import sys +import time +from datetime import datetime + +from azurepg_entra.sqlalchemy import ( + enable_entra_authentication, + enable_entra_authentication_async, +) +from dotenv import load_dotenv +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: + """Run synchronous database queries indefinitely with SQLAlchemy and Entra authentication.""" + + print("=== Running Everlasting Synchronous SQLAlchemy Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create synchronous engine with Entra authentication + engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + enable_entra_authentication(engine) + + execution_count = 0 + + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Sync Execution #{execution_count} at {current_time}") + + try: + with engine.connect() as conn: + # Query 1: Get PostgreSQL version + result = conn.execute(text("SELECT version()")) + row = result.fetchone() + version = row[0] if row else "Unknown" + print(f"Connected to PostgreSQL: {version[:50]}...") + + # Query 2: Get current user + result = conn.execute(text("SELECT current_user")) + row = result.fetchone() + user = row[0] if row else "Unknown" + print(f"Connected as: {user}") + + # Query 3: Get current timestamp + result = conn.execute(text("SELECT now()")) + row = result.fetchone() + timestamp = row[0] if row else "Unknown" + print(f"Server time: {timestamp}") + + print("Sync query execution successful!") + + except Exception as e: + print(f"Database error: {e}") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + engine.dispose() + + +async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: + """Run asynchronous database queries indefinitely with SQLAlchemy and Entra authentication.""" + + print("=== Running Everlasting Asynchronous SQLAlchemy Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create asynchronous engine with Entra authentication + engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + enable_entra_authentication_async(engine) + + execution_count = 0 + + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Async Execution #{execution_count} at {current_time}") + + try: + async with engine.connect() as conn: + # Query 1: Get PostgreSQL version + result = await conn.execute(text("SELECT version()")) + row = result.fetchone() + version = row[0] if row else "Unknown" + print(f"Connected to PostgreSQL: {version[:50]}...") + + # Query 2: Get current user + result = await conn.execute(text("SELECT current_user")) + row = result.fetchone() + user = row[0] if row else "Unknown" + print(f"Connected as: {user}") + + # Query 3: Get current timestamp + result = await conn.execute(text("SELECT now()")) + row = result.fetchone() + timestamp = row[0] if row else "Unknown" + print(f"Server time: {timestamp}") + + print("Async query execution successful!") + + except Exception as e: + print(f"Database error: {e}") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + await asyncio.sleep(interval_minutes * 60) + finally: + await engine.dispose() + + +async def main() -> None: + """Main function with command line argument parsing.""" + parser = argparse.ArgumentParser( + description="Demonstrate everlasting SQLAlchemy connections with Azure Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)" + ) + parser.add_argument( + "--interval", + type=int, + default=2, + help="Query execution interval in minutes (default: 2)" + ) + args = parser.parse_args() + + # Validate environment variables + if not SERVER: + print("Error: POSTGRES_SERVER environment variable is required") + sys.exit(1) + + print(f"Target server: {SERVER}") + print(f"Target database: {DATABASE}") + print(f"Query interval: {args.interval} minutes") + print(f"Mode: {args.mode}\n") + + if args.mode in ("sync", "both"): + run_everlasting_sync_queries(args.interval) + + if args.mode in ("async", "both"): + if args.mode == "both": + print("\n" + "="*60 + "\n") + await run_everlasting_async_queries(args.interval) + + +if __name__ == "__main__": + # Set Windows event loop policy for compatibility if needed + if sys.platform.startswith('win'): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main()) \ No newline at end of file From 2a150e557dea2105648b90c48ce4a2c245a1050e Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Wed, 8 Oct 2025 11:09:00 -0700 Subject: [PATCH 10/19] Add ruff and mypy checks --- .../create_everlasting_db_connection_psycopg2.py | 3 ++- .../create_everlasting_db_connection_psycopg3.py | 7 ++++--- .../create_everlasting_db_connection_sqlalchemy.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py index 1819692..1972bd1 100644 --- a/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py +++ b/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py @@ -10,10 +10,11 @@ from datetime import datetime import psycopg2 -from azurepg_entra.psycopg2 import EntraConnection from dotenv import load_dotenv from psycopg2.pool import ThreadedConnectionPool +from azurepg_entra.psycopg2 import EntraConnection + # Load environment variables from .env file load_dotenv() SERVER = os.getenv("POSTGRES_SERVER") diff --git a/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py index ab13057..a9e4e49 100644 --- a/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py +++ b/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py @@ -11,10 +11,11 @@ import time from datetime import datetime -from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection from dotenv import load_dotenv from psycopg_pool import AsyncConnectionPool, ConnectionPool +from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection + # Load environment variables from .env file load_dotenv() SERVER = os.getenv("POSTGRES_SERVER") @@ -53,7 +54,7 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: # Query 1: Get PostgreSQL version cur.execute("SELECT version()") version = cur.fetchone() - print(f"Connected to PostgreSQL: {version[0][:50]}...") + print(f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}...") # Query 2: Get current user cur.execute("SELECT current_user") @@ -108,7 +109,7 @@ async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: # Query 1: Get PostgreSQL version await cur.execute("SELECT version()") version = await cur.fetchone() - print(f"Connected to PostgreSQL: {version[0][:50]}...") + print(f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}...") # Query 2: Get current user await cur.execute("SELECT current_user") diff --git a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py index 165dfba..4c38253 100644 --- a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py +++ b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py @@ -11,13 +11,14 @@ import time from datetime import datetime +from dotenv import load_dotenv +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + from azurepg_entra.sqlalchemy import ( enable_entra_authentication, enable_entra_authentication_async, ) -from dotenv import load_dotenv -from sqlalchemy import create_engine, text -from sqlalchemy.ext.asyncio import create_async_engine # Load environment variables from .env file load_dotenv() From 8cb5c5f2a02bfe0cb7baf602a09cf943edc3aed1 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Wed, 8 Oct 2025 11:13:33 -0700 Subject: [PATCH 11/19] Fixed ruff formatting issues --- ...eate_everlasting_db_connection_psycopg2.py | 43 +++++---- ...eate_everlasting_db_connection_psycopg3.py | 88 ++++++++++--------- ...te_everlasting_db_connection_sqlalchemy.py | 68 +++++++------- 3 files changed, 102 insertions(+), 97 deletions(-) diff --git a/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py index 1972bd1..9ef9f04 100644 --- a/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py +++ b/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py @@ -1,5 +1,5 @@ """ -Sample demonstrating an everlasting psycopg2 connection with Azure Entra ID authentication +Sample demonstrating an everlasting psycopg2 connection with Azure Entra ID authentication for Azure PostgreSQL that runs queries indefinitely to test token refresh capabilities. """ @@ -23,54 +23,51 @@ def run_everlasting_queries(interval_minutes: int = 2) -> None: """Run database queries indefinitely with psycopg2 and Entra authentication using ThreadedConnectionPool.""" - + print("=== Running Everlasting psycopg2 Connection Pool Example ===") print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - + # Create connection string conninfo = f"postgresql://{SERVER}:5432/{DATABASE}" - + # Create connection pool with EntraConnection factory print("Creating ThreadedConnectionPool with EntraConnection factory...") pool = ThreadedConnectionPool( - minconn=1, - maxconn=3, - dsn=conninfo, - connection_factory=EntraConnection + minconn=1, maxconn=3, dsn=conninfo, connection_factory=EntraConnection ) - + execution_count = 0 - + try: while True: execution_count += 1 current_time = datetime.now().strftime("%H:%M:%S") - + print(f"Execution #{execution_count} at {current_time}") - + # Get connection from pool conn = pool.getconn() - + try: with conn.cursor() as cur: # Query 1: Get PostgreSQL version cur.execute("SELECT version()") version = cur.fetchone() print(f"Connected to PostgreSQL: {version[0][:50]}...") - + # Query 2: Get current user (shows the Entra username) cur.execute("SELECT current_user") user = cur.fetchone() print(f"Connected as: {user[0]}") - + # Query 3: Get current timestamp cur.execute("SELECT now()") timestamp = cur.fetchone() print(f"Server time: {timestamp[0]}") - + print("Query execution successful!") - + except psycopg2.Error as e: print(f"Database error: {e}") except Exception as e: @@ -78,7 +75,7 @@ def run_everlasting_queries(interval_minutes: int = 2) -> None: finally: # Return connection to pool pool.putconn(conn) - + print(f"Waiting {interval_minutes} minutes until next execution...\n") time.sleep(interval_minutes * 60) finally: @@ -96,22 +93,22 @@ def main() -> None: "--interval", type=int, default=2, - help="Query execution interval in minutes (default: 2)" + help="Query execution interval in minutes (default: 2)", ) args = parser.parse_args() - + # Validate environment variables if not SERVER: print("Error: POSTGRES_SERVER environment variable is required") sys.exit(1) - + print(f"Target server: {SERVER}") print(f"Target database: {DATABASE}") print(f"Query interval: {args.interval} minutes\n") - + # Run the everlasting queries run_everlasting_queries(args.interval) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py index a9e4e49..e9c9e72 100644 --- a/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py +++ b/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py @@ -1,6 +1,6 @@ """ -Sample demonstrating everlasting synchronous and asynchronous psycopg3 connections -with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely +Sample demonstrating everlasting synchronous and asynchronous psycopg3 connections +with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely to test token refresh capabilities. """ @@ -24,53 +24,57 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: """Run synchronous database queries indefinitely with psycopg3 and Entra authentication.""" - + print("=== Running Everlasting Synchronous psycopg3 Connection Example ===") print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - + # Create connection pool with Entra authentication pool = ConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, max_size=3, open=False, - connection_class=EntraConnection + connection_class=EntraConnection, ) pool.open() - + execution_count = 0 - + try: with pool: while True: execution_count += 1 current_time = datetime.now().strftime("%H:%M:%S") - + print(f"Sync Execution #{execution_count} at {current_time}") - + try: with pool.connection() as conn, conn.cursor() as cur: # Query 1: Get PostgreSQL version cur.execute("SELECT version()") version = cur.fetchone() - print(f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}...") - + print( + f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." + ) + # Query 2: Get current user cur.execute("SELECT current_user") user = cur.fetchone() print(f"Connected as: {user[0] if user else 'Unknown'}") - + # Query 3: Get current timestamp cur.execute("SELECT now()") timestamp = cur.fetchone() - print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") - + print( + f"Server time: {timestamp[0] if timestamp else 'Unknown'}" + ) + print("Sync query execution successful!") - + except Exception as e: print(f"Database error: {e}") - + print(f"Waiting {interval_minutes} minutes until next execution...\n") time.sleep(interval_minutes * 60) finally: @@ -79,53 +83,57 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: """Run asynchronous database queries indefinitely with psycopg3 and Entra authentication.""" - + print("=== Running Everlasting Asynchronous psycopg3 Connection Example ===") print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - + # Create async connection pool with Entra authentication pool = AsyncConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, max_size=3, open=False, - connection_class=AsyncEntraConnection + connection_class=AsyncEntraConnection, ) await pool.open() - + execution_count = 0 - + try: async with pool: while True: execution_count += 1 current_time = datetime.now().strftime("%H:%M:%S") - + print(f"Async Execution #{execution_count} at {current_time}") - + try: async with pool.connection() as conn, conn.cursor() as cur: # Query 1: Get PostgreSQL version await cur.execute("SELECT version()") version = await cur.fetchone() - print(f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}...") - + print( + f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." + ) + # Query 2: Get current user await cur.execute("SELECT current_user") user = await cur.fetchone() print(f"Connected as: {user[0] if user else 'Unknown'}") - + # Query 3: Get current timestamp await cur.execute("SELECT now()") timestamp = await cur.fetchone() - print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") - + print( + f"Server time: {timestamp[0] if timestamp else 'Unknown'}" + ) + print("Async query execution successful!") - + except Exception as e: print(f"Database error: {e}") - + print(f"Waiting {interval_minutes} minutes until next execution...\n") await asyncio.sleep(interval_minutes * 60) finally: @@ -141,38 +149,38 @@ async def main() -> None: "--mode", choices=["sync", "async", "both"], default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" + help="Run synchronous, asynchronous, or both examples (default: both)", ) parser.add_argument( "--interval", type=int, default=2, - help="Query execution interval in minutes (default: 2)" + help="Query execution interval in minutes (default: 2)", ) args = parser.parse_args() - + # Validate environment variables if not SERVER: print("Error: POSTGRES_SERVER environment variable is required") sys.exit(1) - + print(f"Target server: {SERVER}") print(f"Target database: {DATABASE}") print(f"Query interval: {args.interval} minutes") print(f"Mode: {args.mode}\n") - + if args.mode in ("sync", "both"): run_everlasting_sync_queries(args.interval) - + if args.mode in ("async", "both"): if args.mode == "both": - print("\n" + "="*60 + "\n") + print("\n" + "=" * 60 + "\n") await run_everlasting_async_queries(args.interval) if __name__ == "__main__": # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py index 4c38253..f11bae3 100644 --- a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py +++ b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py @@ -1,6 +1,6 @@ """ -Sample demonstrating everlasting synchronous and asynchronous SQLAlchemy connections -with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely +Sample demonstrating everlasting synchronous and asynchronous SQLAlchemy connections +with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely to test token refresh capabilities. """ @@ -28,24 +28,24 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: """Run synchronous database queries indefinitely with SQLAlchemy and Entra authentication.""" - + print("=== Running Everlasting Synchronous SQLAlchemy Connection Example ===") print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - + # Create synchronous engine with Entra authentication engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") enable_entra_authentication(engine) - + execution_count = 0 - + try: while True: execution_count += 1 current_time = datetime.now().strftime("%H:%M:%S") - + print(f"Sync Execution #{execution_count} at {current_time}") - + try: with engine.connect() as conn: # Query 1: Get PostgreSQL version @@ -53,24 +53,24 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: row = result.fetchone() version = row[0] if row else "Unknown" print(f"Connected to PostgreSQL: {version[:50]}...") - + # Query 2: Get current user result = conn.execute(text("SELECT current_user")) row = result.fetchone() user = row[0] if row else "Unknown" print(f"Connected as: {user}") - + # Query 3: Get current timestamp result = conn.execute(text("SELECT now()")) row = result.fetchone() timestamp = row[0] if row else "Unknown" print(f"Server time: {timestamp}") - + print("Sync query execution successful!") - + except Exception as e: print(f"Database error: {e}") - + print(f"Waiting {interval_minutes} minutes until next execution...\n") time.sleep(interval_minutes * 60) finally: @@ -79,24 +79,24 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: """Run asynchronous database queries indefinitely with SQLAlchemy and Entra authentication.""" - + print("=== Running Everlasting Asynchronous SQLAlchemy Connection Example ===") print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - + # Create asynchronous engine with Entra authentication engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") enable_entra_authentication_async(engine) - + execution_count = 0 - + try: while True: execution_count += 1 current_time = datetime.now().strftime("%H:%M:%S") - + print(f"Async Execution #{execution_count} at {current_time}") - + try: async with engine.connect() as conn: # Query 1: Get PostgreSQL version @@ -104,24 +104,24 @@ async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: row = result.fetchone() version = row[0] if row else "Unknown" print(f"Connected to PostgreSQL: {version[:50]}...") - + # Query 2: Get current user result = await conn.execute(text("SELECT current_user")) row = result.fetchone() user = row[0] if row else "Unknown" print(f"Connected as: {user}") - + # Query 3: Get current timestamp result = await conn.execute(text("SELECT now()")) row = result.fetchone() timestamp = row[0] if row else "Unknown" print(f"Server time: {timestamp}") - + print("Async query execution successful!") - + except Exception as e: print(f"Database error: {e}") - + print(f"Waiting {interval_minutes} minutes until next execution...\n") await asyncio.sleep(interval_minutes * 60) finally: @@ -137,38 +137,38 @@ async def main() -> None: "--mode", choices=["sync", "async", "both"], default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" + help="Run synchronous, asynchronous, or both examples (default: both)", ) parser.add_argument( "--interval", type=int, default=2, - help="Query execution interval in minutes (default: 2)" + help="Query execution interval in minutes (default: 2)", ) args = parser.parse_args() - + # Validate environment variables if not SERVER: print("Error: POSTGRES_SERVER environment variable is required") sys.exit(1) - + print(f"Target server: {SERVER}") print(f"Target database: {DATABASE}") print(f"Query interval: {args.interval} minutes") print(f"Mode: {args.mode}\n") - + if args.mode in ("sync", "both"): run_everlasting_sync_queries(args.interval) - + if args.mode in ("async", "both"): if args.mode == "both": - print("\n" + "="*60 + "\n") + print("\n" + "=" * 60 + "\n") await run_everlasting_async_queries(args.interval) if __name__ == "__main__": # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) From e2a036bfb93b792451ca07f17fcdcfe59fd6f2ae Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Mon, 13 Oct 2025 10:01:02 -0700 Subject: [PATCH 12/19] Make major updates based on Rob's feedback including to GitHub Actions and exception handling --- .github/workflows/pr-dotnet.yml | 25 +++ .github/workflows/pr-python.yml | 31 +++ .github/workflows/pr.yml | 64 ------ python/README.md | 187 ++++-------------- ...on_psycopg2.py => create_db_connection.py} | 6 +- ...py => create_everlasting_db_connection.py} | 0 ...on_psycopg3.py => create_db_connection.py} | 0 ...py => create_everlasting_db_connection.py} | 0 ..._sqlalchemy.py => create_db_connection.py} | 23 ++- ...py => create_everlasting_db_connection.py} | 0 python/src/azurepg_entra/errors.py | 12 +- python/src/azurepg_entra/psycopg2/__init__.py | 20 +- .../psycopg2/entra_connection.py | 20 +- python/src/azurepg_entra/psycopg3/__init__.py | 15 +- .../psycopg3/async_entra_connection.py | 22 +-- .../psycopg3/entra_connection.py | 25 +-- .../sqlalchemy/async_entra_connection.py | 26 +-- .../sqlalchemy/entra_connection.py | 22 +-- run-dotnet-checks.ps1 | 78 ++++++++ run-python-checks.ps1 | 117 +++++++++++ 20 files changed, 386 insertions(+), 307 deletions(-) create mode 100644 .github/workflows/pr-dotnet.yml create mode 100644 .github/workflows/pr-python.yml delete mode 100644 .github/workflows/pr.yml rename python/samples/psycopg2/getting_started/{create_db_connection_psycopg2.py => create_db_connection.py} (94%) rename python/samples/psycopg2/getting_started/{create_everlasting_db_connection_psycopg2.py => create_everlasting_db_connection.py} (100%) rename python/samples/psycopg3/getting_started/{create_db_connection_psycopg3.py => create_db_connection.py} (100%) rename python/samples/psycopg3/getting_started/{create_everlasting_db_connection_psycopg3.py => create_everlasting_db_connection.py} (100%) rename python/samples/sqlalchemy/getting_started/{create_db_connection_sqlalchemy.py => create_db_connection.py} (82%) rename python/samples/sqlalchemy/getting_started/{create_everlasting_db_connection_sqlalchemy.py => create_everlasting_db_connection.py} (100%) create mode 100644 run-dotnet-checks.ps1 create mode 100644 run-python-checks.ps1 diff --git a/.github/workflows/pr-dotnet.yml b/.github/workflows/pr-dotnet.yml new file mode 100644 index 0000000..baa703c --- /dev/null +++ b/.github/workflows/pr-dotnet.yml @@ -0,0 +1,25 @@ +name: PR .NET Checks +on: + pull_request: + branches: [ main ] + paths: + - 'dotnet/**' + - '.github/workflows/pr-dotnet.yml' + +jobs: + dotnet-quality: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-dotnet@v4 + with: + dotnet-version: '9.0.x' + - name: Restore + working-directory: dotnet + run: dotnet restore + - name: Build + working-directory: dotnet + run: dotnet build --configuration Release --no-restore + - name: Test + working-directory: dotnet + run: dotnet test --configuration Release --no-build --verbosity normal \ No newline at end of file diff --git a/.github/workflows/pr-python.yml b/.github/workflows/pr-python.yml new file mode 100644 index 0000000..d91fd43 --- /dev/null +++ b/.github/workflows/pr-python.yml @@ -0,0 +1,31 @@ +name: PR Python Checks +on: + pull_request: + branches: [ main ] + paths: + - 'python/**' + - '.github/workflows/pr-python.yml' + +jobs: + python-quality: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install deps + working-directory: python + run: | + python -m pip install --upgrade pip + pip install .[all] + - name: Ruff Lint + working-directory: python + run: python -m ruff check src tests + - name: Type check + working-directory: python + run: python -m mypy src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py + - name: Tests + if: ${{ always() }} # adjust if you add tests + working-directory: python + run: python -m pytest -q \ No newline at end of file diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml deleted file mode 100644 index df37125..0000000 --- a/.github/workflows/pr.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: PR Checks - -on: - pull_request: - branches: - - main - -jobs: - python-quality-checks: - runs-on: ubuntu-latest - - steps: - # Check out the repo - - name: Checkout code - uses: actions/checkout@v4 - - # Set up Python - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - # Install dependencies - - name: Install dependencies - working-directory: ./python - run: | - python -m pip install --upgrade pip - pip install -e .[all] - - # Run Ruff linter - - name: Run Ruff linting - working-directory: ./python - run: | - ruff check . - - # Run Ruff formatter check - - name: Check Ruff formatting - working-directory: ./python - run: | - ruff format --check . - - # Run mypy on source code (strict) - - name: Run mypy on source code - working-directory: ./python - run: | - mypy --strict src/ - - # Run mypy on samples (strict) - - name: Run mypy on samples - working-directory: ./python - run: | - mypy --strict samples/ - - # Run mypy on tests (basic) - - name: Run mypy on tests - working-directory: ./python - run: | - mypy tests/ - - # Run pytest - - name: Run tests - working-directory: ./python - run: | - pytest tests/ -v --tb=short \ No newline at end of file diff --git a/python/README.md b/python/README.md index 6efe592..8414de2 100644 --- a/python/README.md +++ b/python/README.md @@ -119,58 +119,26 @@ pip install "azurepg-entra[psycopg2]" ### Connection Pooling (Recommended) ```python -from azurepg_entra.psycopg2 import EntraConnection -from psycopg2 import pool -import os - -def main(): - # Connection pooling with Entra authentication - connection_pool = pool.ThreadedConnectionPool( - minconn=1, - maxconn=5, - host="your-server.postgres.database.azure.com", - database="your_database", - connection_factory=EntraConnection - ) - - # Get a connection from the pool - conn = connection_pool.getconn() - - try: - with conn.cursor() as cur: - cur.execute("SELECT current_user, now()") - user, time = cur.fetchone() - print(f"Connected as: {user} at {time}") - finally: - # Return connection to pool - connection_pool.putconn(conn) - connection_pool.closeall() - -if __name__ == "__main__": - main() +from azurepg_entra.psycopg2 import EntraConnection # import library +from psycopg2 import pool # import to use pooling + +with pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host="your-server.postgres.database.azure.com", + database="your_database", + connection_factory=EntraConnection +) as connection_pool: ``` ### Direct Connection ```python -from azurepg_entra.psycopg2 import EntraConnection +from azurepg_entra.psycopg2 import EntraConnection # import library -def main(): - # Direct connection using DSN - conn = EntraConnection( - "postgresql://your-server.postgres.database.azure.com:5432/your_database" - ) - - try: - with conn.cursor() as cur: - cur.execute("SELECT current_user, now()") - user, time = cur.fetchone() - print(f"Connected as: {user} at {time}") - finally: - conn.close() - -if __name__ == "__main__": - main() +with EntraConnection( + "postgresql://your-server.postgres.database.azure.com:5432/your_database" +) as conn ``` --- @@ -187,67 +155,29 @@ pip install "azurepg-entra[psycopg3]" ### Synchronous Connection ```python -from azurepg_entra.psycopg3 import EntraConnection -from psycopg_pool import ConnectionPool - -def main(): - # Connection pooling (recommended for production) - pool = ConnectionPool( - conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", - connection_class=EntraConnection, - min_size=1, # keep at least 1 connection always open - max_size=5, # allow up to 5 concurrent connections - open=False - ) - - pool.open() - try: - with pool.connection() as conn: - with conn.cursor() as cur: - cur.execute("SELECT current_user, now()") - user, time = cur.fetchone() - print(f"Connected as: {user} at {time}") - finally: - pool.close() - -if __name__ == "__main__": - main() +from azurepg_entra.psycopg3 import EntraConnection # import library +from psycopg_pool import ConnectionPool # import to use pooling + +with ConnectionPool( + conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", + connection_class=EntraConnection, + min_size=1, # keep at least 1 connection always open + max_size=5, # allow up to 5 concurrent connections +) as pool ``` ### Asynchronous Connection ```python -import asyncio -import sys -from azurepg_entra.psycopg3 import AsyncEntraConnection -from psycopg_pool import AsyncConnectionPool - -async def main(): - # Async connection pooling (recommended for production) - pool = AsyncConnectionPool( - conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", - connection_class=AsyncEntraConnection, - min_size=1, # keep at least 1 connection always open - max_size=5, # allow up to 5 concurrent connections - open=False - ) - - await pool.open() - try: - async with pool.connection() as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT current_user, now()") - user, time = await cur.fetchone() - print(f"Async connected as: {user} at {time}") - finally: - await pool.close() - -if __name__ == "__main__": - # Windows compatibility for async operations - if sys.platform.startswith('win'): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main()) +from azurepg_entra.psycopg3 import AsyncEntraConnection # import library +from psycopg_pool import AsyncConnectionPool # import to use pooling + +async with AsyncConnectionPool( + conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", + connection_class=AsyncEntraConnection, + min_size=1, # keep at least 1 connection always open + max_size=5, # allow up to 5 concurrent connections +) as pool ``` --- @@ -256,6 +186,8 @@ if __name__ == "__main__": SQLAlchemy integration uses psycopg3 as the backend driver with automatic Entra ID authentication through event listeners. +> **For more information**: See SQLAlchemy's documentation on [controlling how parameters are passed to the DBAPI connect function](https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function). + ### Installation ```bash pip install "azurepg-entra[sqlalchemy]" @@ -264,76 +196,37 @@ pip install "azurepg-entra[sqlalchemy]" ### Synchronous Engine ```python -from sqlalchemy import create_engine, text -from azurepg_entra.sqlalchemy import enable_entra_authentication +from sqlalchemy import create_engine +from azurepg_entra.sqlalchemy import enable_entra_authentication # import library -def main(): - # Create synchronous engine - engine = create_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") - +with create_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") as engine: # Enable Entra ID authentication enable_entra_authentication(engine) # Core usage with engine.connect() as conn: - result = conn.execute(text("SELECT current_user, now()")) - user, time = result.fetchone() - print(f"SQLAlchemy connected as: {user} at {time}") - + # ORM usage from sqlalchemy.orm import sessionmaker Session = sessionmaker(bind=engine) - - with Session() as session: - result = session.execute(text("SELECT current_database()")) - db_name = result.scalar() - print(f"Connected to database: {db_name}") - - engine.dispose() - -if __name__ == "__main__": - main() ``` ### Asynchronous Engine ```python -import asyncio -import sys from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy import text -from azurepg_entra.sqlalchemy import enable_entra_authentication_async +from azurepg_entra.sqlalchemy import enable_entra_authentication_async # import library -async def main(): - # Create asynchronous engine - engine = create_async_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") - +async with create_async_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") as engine: # Enable Entra ID authentication for async enable_entra_authentication_async(engine) # Async Core usage async with engine.connect() as conn: - result = await conn.execute(text("SELECT current_user, now()")) - user, time = result.fetchone() - print(f"Async SQLAlchemy connected as: {user} at {time}") # Async ORM usage from sqlalchemy.ext.asyncio import async_sessionmaker AsyncSession = async_sessionmaker(engine, expire_on_commit=False) - - async with AsyncSession() as session: - result = await session.execute(text("SELECT current_database()")) - db_name = result.scalar() - print(f"Async connected to database: {db_name}") - - await engine.dispose() - -if __name__ == "__main__": - # Windows compatibility for async operations - if sys.platform.startswith('win'): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main()) ``` ## How It Works diff --git a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_db_connection.py similarity index 94% rename from python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py rename to python/samples/psycopg2/getting_started/create_db_connection.py index 1b79300..f979101 100644 --- a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py +++ b/python/samples/psycopg2/getting_started/create_db_connection.py @@ -15,7 +15,7 @@ DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") -def main_sync() -> None: +def main() -> None: try: # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. # This class is applied whenever the connection pool creates a new connection, ensuring that Entra @@ -50,9 +50,9 @@ def main_sync() -> None: connection_pool.closeall() except Exception as e: - print(f"Sync - Error connecting to database: {e}") + print(f"Error connecting to database: {e}") raise if __name__ == "__main__": - main_sync() + main() diff --git a/python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py similarity index 100% rename from python/samples/psycopg2/getting_started/create_everlasting_db_connection_psycopg2.py rename to python/samples/psycopg2/getting_started/create_everlasting_db_connection.py diff --git a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_db_connection.py similarity index 100% rename from python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py rename to python/samples/psycopg3/getting_started/create_db_connection.py diff --git a/python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py similarity index 100% rename from python/samples/psycopg3/getting_started/create_everlasting_db_connection_psycopg3.py rename to python/samples/psycopg3/getting_started/create_everlasting_db_connection.py diff --git a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_db_connection.py similarity index 82% rename from python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py rename to python/samples/sqlalchemy/getting_started/create_db_connection.py index f623bb3..b2dab5e 100644 --- a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py +++ b/python/samples/sqlalchemy/getting_started/create_db_connection.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv from sqlalchemy import create_engine, text -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from azurepg_entra.sqlalchemy import ( enable_entra_authentication, @@ -71,16 +71,33 @@ async def main_async() -> None: # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function enable_entra_authentication_async(engine) + # Core usage example async with engine.connect() as conn: # Query 1 result = await conn.execute(text("SELECT now()")) row = result.fetchone() - print(f"Async - Database time: {row[0] if row else 'Unknown'}") + print(f"Async Core - Database time: {row[0] if row else 'Unknown'}") # Query 2 result = await conn.execute(text("SELECT current_user")) row = result.fetchone() - print(f"Async - Connected as: {row[0] if row else 'Unknown'}") + print(f"Async Core - Connected as: {row[0] if row else 'Unknown'}") + + # ORM usage example with async_sessionmaker + AsyncSession = async_sessionmaker(engine, expire_on_commit=False) + + async with AsyncSession() as session: + # Query 1 + result = await session.execute(text("SELECT current_database()")) + db_name = result.scalar() + print(f"Async ORM - Connected to database: {db_name}") + + # Query 2 + result = await session.execute(text("SELECT version()")) + version = result.scalar() + # Just show the first part of the version string for cleaner output + version_short = version.split(' on ')[0] if version else 'Unknown' + print(f"Async ORM - PostgreSQL version: {version_short}") # Clean up the engine await engine.dispose() diff --git a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py similarity index 100% rename from python/samples/sqlalchemy/getting_started/create_everlasting_db_connection_sqlalchemy.py rename to python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py diff --git a/python/src/azurepg_entra/errors.py b/python/src/azurepg_entra/errors.py index 6249a29..4478222 100644 --- a/python/src/azurepg_entra/errors.py +++ b/python/src/azurepg_entra/errors.py @@ -1,34 +1,34 @@ -class EntraIdBaseError(Exception): +class AzurePgEntraError(Exception): """Base class for all custom exceptions in the project.""" pass -class TokenDecodeError(EntraIdBaseError): +class TokenDecodeError(AzurePgEntraError): """Raised when a token value is invalid.""" pass -class UsernameExtractionError(EntraIdBaseError): +class UsernameExtractionError(AzurePgEntraError): """Raised when username cannot be extracted from token.""" pass -class CredentialValueError(EntraIdBaseError): +class CredentialValueError(AzurePgEntraError): """Raised when token credential is invalid.""" pass -class EntraConnectionValueError(EntraIdBaseError): +class EntraConnectionValueError(AzurePgEntraError): """Raised when Entra connection credentials are invalid.""" pass -class ScopePermissionError(EntraIdBaseError): +class ScopePermissionError(AzurePgEntraError): """Raised when the provided scope does not have sufficient permissions.""" pass diff --git a/python/src/azurepg_entra/psycopg2/__init__.py b/python/src/azurepg_entra/psycopg2/__init__.py index ec25784..13339c1 100644 --- a/python/src/azurepg_entra/psycopg2/__init__.py +++ b/python/src/azurepg_entra/psycopg2/__init__.py @@ -27,18 +27,10 @@ ) """ -try: - from .entra_connection import ( - EntraConnection, - ) - - __all__ = [ - "EntraConnection", - ] +from .entra_connection import ( + EntraConnection, +) -except ImportError as e: - # Provide a helpful error message if psycopg2 dependencies are missing - raise ImportError( - "psycopg2 dependencies are not installed. " - "Install them with: pip install azurepg-entra[psycopg2]" - ) from e +__all__ = [ + "EntraConnection", +] diff --git a/python/src/azurepg_entra/psycopg2/entra_connection.py b/python/src/azurepg_entra/psycopg2/entra_connection.py index 82722f7..d7f9fbd 100644 --- a/python/src/azurepg_entra/psycopg2/entra_connection.py +++ b/python/src/azurepg_entra/psycopg2/entra_connection.py @@ -2,15 +2,20 @@ from typing import Any from azure.core.credentials import TokenCredential -from psycopg2.extensions import connection, make_dsn, parse_dsn + +try: + from psycopg2.extensions import connection, make_dsn, parse_dsn +except ImportError as e: + # Provide a helpful error message if psycopg2 dependencies are missing + raise ImportError( + "psycopg2 dependencies are not installed. " + "Install them with: pip install azurepg-entra[psycopg2]" + ) from e from azurepg_entra.core import get_entra_conninfo from azurepg_entra.errors import ( CredentialValueError, EntraConnectionValueError, - ScopePermissionError, - TokenDecodeError, - UsernameExtractionError, ) @@ -52,12 +57,7 @@ def __init__(self, dsn: str, **kwargs: Any) -> None: if not has_user or not has_password: try: entra_creds = get_entra_conninfo(credential) - except ( - TokenDecodeError, - UsernameExtractionError, - ScopePermissionError, - ) as e: - print(repr(e)) + except (Exception) as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/src/azurepg_entra/psycopg3/__init__.py b/python/src/azurepg_entra/psycopg3/__init__.py index 879889c..5d13eb9 100644 --- a/python/src/azurepg_entra/psycopg3/__init__.py +++ b/python/src/azurepg_entra/psycopg3/__init__.py @@ -32,14 +32,7 @@ ) """ -try: - from .async_entra_connection import AsyncEntraConnection - from .entra_connection import EntraConnection - - __all__ = ["EntraConnection", "AsyncEntraConnection"] -except ImportError as e: - # Provide a helpful error message if psycopg dependencies are missing - raise ImportError( - "psycopg3 dependencies are not installed. " - "Install them with: pip install azurepg-entra[psycopg3]" - ) from e +from .async_entra_connection import AsyncEntraConnection +from .entra_connection import EntraConnection + +__all__ = ["EntraConnection", "AsyncEntraConnection"] diff --git a/python/src/azurepg_entra/psycopg3/async_entra_connection.py b/python/src/azurepg_entra/psycopg3/async_entra_connection.py index 2d1b5f2..5eaaefc 100644 --- a/python/src/azurepg_entra/psycopg3/async_entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/async_entra_connection.py @@ -3,15 +3,20 @@ from typing import Any from azure.core.credentials_async import AsyncTokenCredential -from psycopg import AsyncConnection + +try: + from psycopg import AsyncConnection +except ImportError as e: + # Provide a helpful error message if psycopg3 dependencies are missing + raise ImportError( + "psycopg3 dependencies are not installed. " + "Install them with: pip install azurepg-entra[psycopg3]" + ) from e from azurepg_entra.core import get_entra_conninfo_async from azurepg_entra.errors import ( CredentialValueError, EntraConnectionValueError, - ScopePermissionError, - TokenDecodeError, - UsernameExtractionError, ) try: @@ -24,7 +29,7 @@ class AsyncEntraConnection(AsyncConnection[tuple[Any, ...]]): """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" @classmethod - async def connect(cls, *args: Any, **kwargs: Any) -> Self: + async def connect(cls, *args: Any, **kwargs: Any) -> "AsyncEntraConnection": """Establishes an asynchronous PostgreSQL connection using Entra authentication. This method automatically acquires Azure Entra ID credentials when user or password @@ -55,12 +60,7 @@ async def connect(cls, *args: Any, **kwargs: Any) -> Self: if not kwargs.get("user") or not kwargs.get("password"): try: entra_conninfo = await get_entra_conninfo_async(credential) - except ( - TokenDecodeError, - UsernameExtractionError, - ScopePermissionError, - ) as e: - print(repr(e)) + except (Exception) as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/src/azurepg_entra/psycopg3/entra_connection.py b/python/src/azurepg_entra/psycopg3/entra_connection.py index 3246938..3da72e6 100644 --- a/python/src/azurepg_entra/psycopg3/entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/entra_connection.py @@ -1,21 +1,21 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any +from azure.core.credentials import TokenCredential try: - from typing import Self -except ImportError: - from typing_extensions import Self # fallback for older Python -from azure.core.credentials import TokenCredential -from psycopg import Connection + from psycopg import Connection +except ImportError as e: + # Provide a helpful error message if psycopg3 dependencies are missing + raise ImportError( + "psycopg3 dependencies are not installed. " + "Install them with: pip install azurepg-entra[psycopg3]" + ) from e from azurepg_entra.core import get_entra_conninfo from azurepg_entra.errors import ( CredentialValueError, EntraConnectionValueError, - ScopePermissionError, - TokenDecodeError, - UsernameExtractionError, ) @@ -23,7 +23,7 @@ class EntraConnection(Connection[tuple[Any, ...]]): """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" @classmethod - def connect(cls, *args: Any, **kwargs: Any) -> Self: + def connect(cls, *args: Any, **kwargs: Any) -> "EntraConnection": """Establishes a synchronous PostgreSQL connection using Entra authentication. This method automatically acquires Azure Entra ID credentials when user or password @@ -54,12 +54,7 @@ def connect(cls, *args: Any, **kwargs: Any) -> Self: if not kwargs.get("user") or not kwargs.get("password"): try: entra_conninfo = get_entra_conninfo(credential) - except ( - TokenDecodeError, - UsernameExtractionError, - ScopePermissionError, - ) as e: - print(repr(e)) + except (Exception) as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py index 7b79a31..20f63b8 100644 --- a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py @@ -2,17 +2,22 @@ from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential -from sqlalchemy import event -from sqlalchemy.engine import Dialect -from sqlalchemy.ext.asyncio import AsyncEngine + +try: + from sqlalchemy import event + from sqlalchemy.engine import Dialect + from sqlalchemy.ext.asyncio import AsyncEngine +except ImportError as e: + # Provide a helpful error message if SQLAlchemy dependencies are missing + raise ImportError( + "SQLAlchemy dependencies are not installed. " + "Install them with: pip install azurepg-entra[sqlalchemy]" + ) from e from azurepg_entra.core import get_entra_conninfo from azurepg_entra.errors import ( CredentialValueError, EntraConnectionValueError, - ScopePermissionError, - TokenDecodeError, - UsernameExtractionError, ) @@ -22,6 +27,8 @@ def enable_entra_authentication_async(engine: AsyncEngine) -> None: This function registers an event listener that automatically provides Entra ID credentials for each database connection if they are not already set. + Event handlers do not support async behavior so the token fetching will still + be synchronous. Args: engine: The async SQLAlchemy Engine to enable Entra authentication for @@ -58,12 +65,7 @@ def provide_token_async( else None ) entra_creds = get_entra_conninfo(sync_credential) - except ( - TokenDecodeError, - UsernameExtractionError, - ScopePermissionError, - ) as e: - print(repr(e)) + except (Exception) as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/src/azurepg_entra/sqlalchemy/entra_connection.py b/python/src/azurepg_entra/sqlalchemy/entra_connection.py index 365fd48..972ab3f 100644 --- a/python/src/azurepg_entra/sqlalchemy/entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/entra_connection.py @@ -1,16 +1,21 @@ from typing import Any from azure.core.credentials import TokenCredential -from sqlalchemy import Engine, event -from sqlalchemy.engine import Dialect + +try: + from sqlalchemy import Engine, event + from sqlalchemy.engine import Dialect +except ImportError as e: + # Provide a helpful error message if SQLAlchemy dependencies are missing + raise ImportError( + "SQLAlchemy dependencies are not installed. " + "Install them with: pip install azurepg-entra[sqlalchemy]" + ) from e from azurepg_entra.core import get_entra_conninfo from azurepg_entra.errors import ( CredentialValueError, EntraConnectionValueError, - ScopePermissionError, - TokenDecodeError, - UsernameExtractionError, ) @@ -48,12 +53,7 @@ def provide_token( if not has_user or not has_password: try: entra_creds = get_entra_conninfo(credential) - except ( - TokenDecodeError, - UsernameExtractionError, - ScopePermissionError, - ) as e: - print(repr(e)) + except (Exception) as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/run-dotnet-checks.ps1 b/run-dotnet-checks.ps1 new file mode 100644 index 0000000..dda71f8 --- /dev/null +++ b/run-dotnet-checks.ps1 @@ -0,0 +1,78 @@ +param( + [string]$Configuration = "Release", + [switch]$Verbose, + [switch]$Help +) + +<# +.SYNOPSIS + Run .NET quality checks (restore, build, test) locally. + +.DESCRIPTION + Mirrors the CI steps in pr-dotnet.yml for confidence before pushing. + +.PARAMETER Configuration + Build configuration (Release or Debug). Defaults to Release to match CI. + +.PARAMETER Verbose + Show full command output (otherwise minimal). + +.EXAMPLE + ./run-dotnet-checks.ps1 + +.EXAMPLE + ./run-dotnet-checks.ps1 -Configuration Debug +#> + +if ($Help) { + Get-Help -Detailed -ErrorAction SilentlyContinue + Write-Host "Usage: ./run-dotnet-checks.ps1 [-Configuration Release|Debug] [-Verbose]" -ForegroundColor Cyan + exit 0 +} + +function Write-CheckResult { + param([string]$Name, [bool]$Success, [string]$Message = "") + if ($Success) { + Write-Host "PASS $Name" -ForegroundColor Green + if ($Message) { Write-Host " $Message" -ForegroundColor Gray } + } else { + Write-Host "FAIL $Name" -ForegroundColor Red + if ($Message) { Write-Host " $Message" -ForegroundColor Gray } + $script:OverallSuccess = $false + } +} + +$OverallSuccess = $true +$dotnetRoot = Join-Path (Get-Location) "dotnet" +if (-not (Test-Path $dotnetRoot)) { + Write-Host "dotnet/ directory not found" -ForegroundColor Red + exit 1 +} + +Push-Location $dotnetRoot +try { + if (-not (Get-Command dotnet -ErrorAction SilentlyContinue)) { + Write-Host ".NET SDK not found in PATH" -ForegroundColor Red + exit 1 + } + + $sdkVersion = dotnet --version + Write-Host "Using .NET SDK $sdkVersion (Configuration=$Configuration)" + + Write-Host "Restoring packages" -ForegroundColor Blue + if ($Verbose) { dotnet restore } else { dotnet restore --verbosity minimal } + Write-CheckResult "restore" ($LASTEXITCODE -eq 0) + + Write-Host "Building solution" -ForegroundColor Blue + if ($Verbose) { dotnet build --no-restore --configuration $Configuration } else { dotnet build --no-restore --configuration $Configuration --verbosity minimal } + Write-CheckResult "build" ($LASTEXITCODE -eq 0) + + Write-Host "Running tests" -ForegroundColor Blue + if ($Verbose) { dotnet test --no-build --configuration $Configuration --verbosity normal } else { dotnet test --no-build --configuration $Configuration --verbosity minimal } + Write-CheckResult "test" ($LASTEXITCODE -eq 0) +} +finally { + Pop-Location +} + +if (-not $OverallSuccess) { exit 1 } diff --git a/run-python-checks.ps1 b/run-python-checks.ps1 new file mode 100644 index 0000000..1f443c9 --- /dev/null +++ b/run-python-checks.ps1 @@ -0,0 +1,117 @@ +param( + [switch]$Verbose, + [switch]$Help, + [switch]$RecreateVenv +) + +<# +.SYNOPSIS + Run Python quality checks (lint, type, tests) locally. + +.DESCRIPTION + Mirrors the CI steps defined in pr-python.yml: install deps (.[all]), Ruff lint, mypy (target + package), pytest, and import validation. + +.PARAMETER Verbose + Show full tool output instead of suppressing it. + +.PARAMETER RecreateVenv + Delete and recreate the .venv before installing dependencies. + +.EXAMPLE + ./run-python-checks.ps1 + +.EXAMPLE + ./run-python-checks.ps1 -Verbose + +.EXAMPLE + ./run-python-checks.ps1 -RecreateVenv +#> + +if ($Help) { + Get-Help -Detailed -ErrorAction SilentlyContinue + Write-Host "Usage: ./run-python-checks.ps1 [-Verbose] [-RecreateVenv]" -ForegroundColor Cyan + exit 0 +} + +function Write-CheckResult { + param([string]$Name, [bool]$Success, [string]$Message = "") + if ($Success) { + Write-Host "PASS $Name" -ForegroundColor Green + if ($Message) { Write-Host " $Message" -ForegroundColor Gray } + } else { + Write-Host "FAIL $Name" -ForegroundColor Red + if ($Message) { Write-Host " $Message" -ForegroundColor Gray } + $script:OverallSuccess = $false + } +} + +$OverallSuccess = $true +$pythonRoot = Join-Path (Get-Location) "python" +if (-not (Test-Path $pythonRoot)) { + Write-Host "python/ directory not found" -ForegroundColor Red + exit 1 +} + +$venvPath = Join-Path $pythonRoot ".venv" +$venvPython = Join-Path $venvPath "Scripts/python.exe" + +Push-Location $pythonRoot +try { + # Use explicit parentheses so PowerShell doesn't mis-bind -and as a parameter to Test-Path on some shells + if ((Test-Path $venvPath) -and $RecreateVenv) { + Write-Host "Recreating virtual environment..." -ForegroundColor Blue + Remove-Item -Recurse -Force $venvPath + } + if (-not (Test-Path $venvPath)) { + Write-Host "Creating virtual environment" -ForegroundColor Blue + python -m venv $venvPath + if ($LASTEXITCODE -ne 0) { Write-CheckResult "Prepare venv" $false "python -m venv failed"; exit 1 } + Write-CheckResult "Prepare venv" $true "Created new venv" + } else { + Write-CheckResult "Prepare venv" $true "Reused existing venv" + } + + if (-not (Test-Path $venvPython)) { Write-CheckResult "Venv Python present" $false "Not found"; exit 1 } + $pythonVersion = & $venvPython --version 2>&1 + Write-Host "Using $pythonVersion" + + & $venvPython -m pip install --upgrade pip | Out-Null + Write-CheckResult "pip upgrade" ($LASTEXITCODE -eq 0) + + Write-Host "Installing project deps (editable .[all])" -ForegroundColor Blue + & $venvPython -m pip install -e .[all] | Out-Null + $depsOk = $LASTEXITCODE -eq 0 + if (-not $depsOk) { Write-CheckResult "Install deps" $false; exit 1 } else { Write-CheckResult "Install deps" $true } + + & $venvPython -m pip install types-psycopg2 aiohttp | Out-Null + Write-CheckResult "Extra deps (types-psycopg2,aiohttp)" ($LASTEXITCODE -eq 0) + + # Ruff + Write-Host "Running Ruff lint" -ForegroundColor Blue + if ($Verbose) { & $venvPython -m ruff check ./src ./tests } else { & $venvPython -m ruff check ./src ./tests *> $null } + Write-CheckResult "ruff lint" ($LASTEXITCODE -eq 0) + + # mypy target + if ($Verbose) { & $venvPython -m mypy ./src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py } else { & $venvPython -m mypy ./src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py *> $null } + Write-CheckResult "mypy (target)" ($LASTEXITCODE -eq 0) + + # mypy all + if ($Verbose) { & $venvPython -m mypy ./src/azurepg_entra/ } else { & $venvPython -m mypy ./src/azurepg_entra/ *> $null } + Write-CheckResult "mypy (all)" ($LASTEXITCODE -eq 0) + + if (Test-Path "tests") { + Write-Host "Running pytest" -ForegroundColor Blue + if ($Verbose) { & $venvPython -m pytest tests -v } else { & $venvPython -m pytest tests -q *> $null } + Write-CheckResult "pytest" ($LASTEXITCODE -eq 0) + } else { + Write-Host "WARN No tests directory present" -ForegroundColor Yellow + } + + & $venvPython -c "import sys; sys.path.insert(0, 'src'); import azurepg_entra, azurepg_entra.core" 2>$null + Write-CheckResult "Import validation" ($LASTEXITCODE -eq 0) +} +finally { + Pop-Location +} + +if (-not $OverallSuccess) { exit 1 } From 0f0d7965c3260034a5190890923f0d2ae45ad2a9 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Mon, 13 Oct 2025 20:52:26 -0700 Subject: [PATCH 13/19] Minor modifications such as file renaming and sample program context management --- .../getting_started/create_db_connection.py | 61 +++---- .../create_everlasting_db_connection.py | 28 +-- .../getting_started/create_db_connection.py | 109 ++++++------ .../create_everlasting_db_connection.py | 159 ++++++++---------- .../getting_started/create_db_connection.py | 123 ++++++-------- .../create_everlasting_db_connection.py | 72 +++++--- python/src/azurepg_entra/__init__.py | 1 + python/src/azurepg_entra/core.py | 2 + python/src/azurepg_entra/errors.py | 2 + python/src/azurepg_entra/psycopg2/__init__.py | 21 +-- .../psycopg2/entra_connection.py | 14 +- python/src/azurepg_entra/psycopg3/__init__.py | 20 +-- .../psycopg3/async_entra_connection.py | 11 +- .../psycopg3/entra_connection.py | 5 +- .../src/azurepg_entra/sqlalchemy/__init__.py | 23 +-- .../sqlalchemy/async_entra_connection.py | 24 +-- .../sqlalchemy/entra_connection.py | 6 +- ...xtension.py => test_entra_id_extension.py} | 21 ++- ...xtension.py => test_entra_id_extension.py} | 1 + ...xtension.py => test_entra_id_extension.py} | 2 +- .../postgresql/test_core_functionality.py | 1 + 21 files changed, 303 insertions(+), 403 deletions(-) rename python/tests/azure/data/postgresql/psycopg2/{test_psycopg2_entra_id_extension.py => test_entra_id_extension.py} (73%) rename python/tests/azure/data/postgresql/psycopg3/{test_psycopg3_entra_id_extension.py => test_entra_id_extension.py} (99%) rename python/tests/azure/data/postgresql/sqlalchemy/{test_sqlalchemy_entra_id_extension.py => test_entra_id_extension.py} (100%) diff --git a/python/samples/psycopg2/getting_started/create_db_connection.py b/python/samples/psycopg2/getting_started/create_db_connection.py index f979101..803a4ce 100644 --- a/python/samples/psycopg2/getting_started/create_db_connection.py +++ b/python/samples/psycopg2/getting_started/create_db_connection.py @@ -6,7 +6,6 @@ from dotenv import load_dotenv from psycopg2 import pool - from azurepg_entra.psycopg2 import EntraConnection # Load environment variables from .env file @@ -16,42 +15,32 @@ def main() -> None: + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/docs/advanced.html#subclassing-connection + connection_pool = pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=SERVER, + database=DATABASE, + connection_factory=EntraConnection, + ) + + conn = connection_pool.getconn() try: - # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. - # This class is applied whenever the connection pool creates a new connection, ensuring that Entra - # authentication tokens are properly managed and refreshed so that each connection uses a valid token. - # - # For more details, see: https://www.psycopg.org/docs/advanced.html#subclassing-connection - connection_pool = pool.ThreadedConnectionPool( - minconn=1, - maxconn=5, - host=SERVER, - database=DATABASE, - connection_factory=EntraConnection, - ) - - # Get a connection from the pool - conn = connection_pool.getconn() - - try: - with conn.cursor() as cur: - # Query 1 - cur.execute("SELECT now()") - result = cur.fetchone() - print(f"Database time: {result[0]}") - - # Query 2 - cur.execute("SELECT current_user") - user = cur.fetchone() - print(f"Connected as: {user[0]}") - finally: - # Return connection to pool - connection_pool.putconn(conn) - connection_pool.closeall() - - except Exception as e: - print(f"Error connecting to database: {e}") - raise + with conn.cursor() as cur: + cur.execute("SELECT now()") + result = cur.fetchone() + print(f"Database time: {result[0]}") + + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Connected as: {user[0]}") + finally: + connection_pool.putconn(conn) + connection_pool.closeall() if __name__ == "__main__": diff --git a/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py b/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py index 9ef9f04..13c85eb 100644 --- a/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py +++ b/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py @@ -8,11 +8,8 @@ import sys import time from datetime import datetime - -import psycopg2 from dotenv import load_dotenv from psycopg2.pool import ThreadedConnectionPool - from azurepg_entra.psycopg2 import EntraConnection # Load environment variables from .env file @@ -24,7 +21,6 @@ def run_everlasting_queries(interval_minutes: int = 2) -> None: """Run database queries indefinitely with psycopg2 and Entra authentication using ThreadedConnectionPool.""" - print("=== Running Everlasting psycopg2 Connection Pool Example ===") print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") @@ -32,23 +28,22 @@ def run_everlasting_queries(interval_minutes: int = 2) -> None: conninfo = f"postgresql://{SERVER}:5432/{DATABASE}" # Create connection pool with EntraConnection factory - print("Creating ThreadedConnectionPool with EntraConnection factory...") + print("Creating connection pool with Entra ID authentication enabled...") pool = ThreadedConnectionPool( minconn=1, maxconn=3, dsn=conninfo, connection_factory=EntraConnection ) execution_count = 0 + # Get one connection and reuse it throughout the program + conn = pool.getconn() + try: while True: execution_count += 1 current_time = datetime.now().strftime("%H:%M:%S") - print(f"Execution #{execution_count} at {current_time}") - # Get connection from pool - conn = pool.getconn() - try: with conn.cursor() as cur: # Query 1: Get PostgreSQL version @@ -68,19 +63,13 @@ def run_everlasting_queries(interval_minutes: int = 2) -> None: print("Query execution successful!") - except psycopg2.Error as e: - print(f"Database error: {e}") except Exception as e: - print(f"Unexpected error: {e}") - finally: - # Return connection to pool - pool.putconn(conn) + print(f"Database error: {e}") print(f"Waiting {interval_minutes} minutes until next execution...\n") time.sleep(interval_minutes * 60) finally: - # Close all connections in the pool - print("Closing connection pool...") + pool.putconn(conn) pool.closeall() @@ -101,11 +90,6 @@ def main() -> None: if not SERVER: print("Error: POSTGRES_SERVER environment variable is required") sys.exit(1) - - print(f"Target server: {SERVER}") - print(f"Target database: {DATABASE}") - print(f"Query interval: {args.interval} minutes\n") - # Run the everlasting queries run_everlasting_queries(args.interval) diff --git a/python/samples/psycopg3/getting_started/create_db_connection.py b/python/samples/psycopg3/getting_started/create_db_connection.py index eebf337..4ca0721 100644 --- a/python/samples/psycopg3/getting_started/create_db_connection.py +++ b/python/samples/psycopg3/getting_started/create_db_connection.py @@ -1,5 +1,5 @@ """ -Sample demonstrating both synchronous and asynchronous psycopg connections +Sample demonstrating both synchronous and asynchronous psycopg3 connections with Azure Entra ID authentication for Azure PostgreSQL. """ @@ -10,7 +10,6 @@ from dotenv import load_dotenv from psycopg_pool import AsyncConnectionPool, ConnectionPool - from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection # Load environment variables from .env file @@ -22,65 +21,55 @@ def main_sync() -> None: """Synchronous connection example using psycopg with Entra ID authentication.""" - try: - # We use the SyncEntraConnection class to enable synchronous Entra-based authentication for database access. - # This class is applied whenever the connection pool creates a new connection, ensuring that Entra - # authentication tokens are properly managed and refreshed so that each connection uses a valid token. - # - # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect - pool = ConnectionPool( - conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", - min_size=1, - max_size=5, - open=False, - connection_class=EntraConnection, - ) - pool.open() - with pool, pool.connection() as conn, conn.cursor() as cur: - # Query 1 - cur.execute("SELECT now()") - result = cur.fetchone() - print(f"Sync - Database time: {result}") - - # Query 2 - cur.execute("SELECT current_user") - user = cur.fetchone() - print(f"Sync - Connected as: {user[0] if user else 'Unknown'}") - except Exception as e: - print(f"Sync - Error connecting to database: {e}") - raise + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = ConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=5, + open=False, + connection_class=EntraConnection, + ) + with pool, pool.connection() as conn, conn.cursor() as cur: + # Query 1 + cur.execute("SELECT now()") + result = cur.fetchone() + print(f"Sync - Database time: {result}") + + # Query 2 + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Sync - Connected as: {user[0] if user else 'Unknown'}") async def main_async() -> None: """Asynchronous connection example using psycopg with Entra ID authentication.""" - try: - # We use the AsyncEntraConnection class to enable asynchronous Entra-based authentication for database access. - # This class is applied whenever the connection pool creates a new connection, ensuring that Entra - # authentication tokens are properly managed and refreshed so that each connection uses a valid token. - # - # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect - pool = AsyncConnectionPool( - conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", - min_size=1, - max_size=5, - open=False, - connection_class=AsyncEntraConnection, - ) - await pool.open() - async with pool, pool.connection() as conn, conn.cursor() as cur: - # Query 1 - await cur.execute("SELECT now()") - result = await cur.fetchone() - print(f"Async - Database time: {result}") - - # Query 2 - await cur.execute("SELECT current_user") - user = await cur.fetchone() - print(f"Async - Connected as: {user[0] if user else 'Unknown'}") - except Exception as e: - print(f"Async - Error connecting to database: {e}") - raise + # We use the AsyncEntraConnection class to enable asynchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = AsyncConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=5, + open=False, + connection_class=AsyncEntraConnection, + ) + async with pool, pool.connection() as conn, conn.cursor() as cur: + # Query 1 + await cur.execute("SELECT now()") + result = await cur.fetchone() + print(f"Async - Database time: {result}") + + # Query 2 + await cur.execute("SELECT current_user") + user = await cur.fetchone() + print(f"Async - Connected as: {user[0] if user else 'Unknown'}") async def main(mode: str = "async") -> None: @@ -93,9 +82,9 @@ async def main(mode: str = "async") -> None: print("=== Running Synchronous Example ===") try: main_sync() - print("✅ Sync example completed successfully!") + print("Sync example completed successfully!") except Exception as e: - print(f"❌ Sync example failed: {e}") + print(f"Sync example failed: {e}") if mode in ("async", "both"): if mode == "both": @@ -104,9 +93,9 @@ async def main(mode: str = "async") -> None: print("=== Running Asynchronous Example ===") try: await main_async() - print("✅ Async example completed successfully!") + print("Async example completed successfully!") except Exception as e: - print(f"❌ Async example failed: {e}") + print(f"Async example failed: {e}") if __name__ == "__main__": diff --git a/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py b/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py index e9c9e72..1231703 100644 --- a/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py +++ b/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py @@ -9,11 +9,10 @@ import os import sys import time -from datetime import datetime +from datetime import datetime from dotenv import load_dotenv from psycopg_pool import AsyncConnectionPool, ConnectionPool - from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection # Load environment variables from .env file @@ -29,7 +28,11 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - # Create connection pool with Entra authentication + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect pool = ConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, @@ -37,48 +40,37 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: open=False, connection_class=EntraConnection, ) - pool.open() execution_count = 0 - - try: - with pool: - while True: - execution_count += 1 - current_time = datetime.now().strftime("%H:%M:%S") - - print(f"Sync Execution #{execution_count} at {current_time}") - - try: - with pool.connection() as conn, conn.cursor() as cur: - # Query 1: Get PostgreSQL version - cur.execute("SELECT version()") - version = cur.fetchone() - print( - f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." - ) - - # Query 2: Get current user - cur.execute("SELECT current_user") - user = cur.fetchone() - print(f"Connected as: {user[0] if user else 'Unknown'}") - - # Query 3: Get current timestamp - cur.execute("SELECT now()") - timestamp = cur.fetchone() - print( - f"Server time: {timestamp[0] if timestamp else 'Unknown'}" - ) - - print("Sync query execution successful!") - - except Exception as e: - print(f"Database error: {e}") - - print(f"Waiting {interval_minutes} minutes until next execution...\n") - time.sleep(interval_minutes * 60) - finally: - pool.close() + with pool: + # Get connection from pool + conn = pool.getconn() + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + print(f"Sync Execution #{execution_count} at {current_time}") + with conn.cursor() as cur: + # Query 1: Get PostgreSQL version + cur.execute("SELECT version()") + version = cur.fetchone() + print( + f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." + ) + + # Query 2: Get current user + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Connected as: {user[0] if user else 'Unknown'}") + + # Query 3: Get current timestamp + cur.execute("SELECT now()") + timestamp = cur.fetchone() + print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") + + print("Sync query execution successful!") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: @@ -88,7 +80,11 @@ async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - # Create async connection pool with Entra authentication + # We use the AsyncEntraConnection class to enable asynchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect pool = AsyncConnectionPool( conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", min_size=1, @@ -96,48 +92,38 @@ async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: open=False, connection_class=AsyncEntraConnection, ) - await pool.open() execution_count = 0 - - try: - async with pool: - while True: - execution_count += 1 - current_time = datetime.now().strftime("%H:%M:%S") - - print(f"Async Execution #{execution_count} at {current_time}") - - try: - async with pool.connection() as conn, conn.cursor() as cur: - # Query 1: Get PostgreSQL version - await cur.execute("SELECT version()") - version = await cur.fetchone() - print( - f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." - ) - - # Query 2: Get current user - await cur.execute("SELECT current_user") - user = await cur.fetchone() - print(f"Connected as: {user[0] if user else 'Unknown'}") - - # Query 3: Get current timestamp - await cur.execute("SELECT now()") - timestamp = await cur.fetchone() - print( - f"Server time: {timestamp[0] if timestamp else 'Unknown'}" - ) - - print("Async query execution successful!") - - except Exception as e: - print(f"Database error: {e}") - - print(f"Waiting {interval_minutes} minutes until next execution...\n") - await asyncio.sleep(interval_minutes * 60) - finally: - await pool.close() + async with pool: + # Get connection from pool + conn = await pool.getconn() + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Async Execution #{execution_count} at {current_time}") + async with conn.cursor() as cur: + # Query 1: Get PostgreSQL version + await cur.execute("SELECT version()") + version = await cur.fetchone() + print( + f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." + ) + + # Query 2: Get current user + await cur.execute("SELECT current_user") + user = await cur.fetchone() + print(f"Connected as: {user[0] if user else 'Unknown'}") + + # Query 3: Get current timestamp + await cur.execute("SELECT now()") + timestamp = await cur.fetchone() + print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") + + print("Async query execution successful!") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) async def main() -> None: @@ -164,11 +150,6 @@ async def main() -> None: print("Error: POSTGRES_SERVER environment variable is required") sys.exit(1) - print(f"Target server: {SERVER}") - print(f"Target database: {DATABASE}") - print(f"Query interval: {args.interval} minutes") - print(f"Mode: {args.mode}\n") - if args.mode in ("sync", "both"): run_everlasting_sync_queries(args.interval) diff --git a/python/samples/sqlalchemy/getting_started/create_db_connection.py b/python/samples/sqlalchemy/getting_started/create_db_connection.py index b2dab5e..12d775d 100644 --- a/python/samples/sqlalchemy/getting_started/create_db_connection.py +++ b/python/samples/sqlalchemy/getting_started/create_db_connection.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv from sqlalchemy import create_engine, text -from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine from azurepg_entra.sqlalchemy import ( enable_entra_authentication, @@ -26,84 +26,59 @@ def main_sync() -> None: """Synchronous connection example using SQLAlchemy with Entra ID authentication.""" - try: - # Create a synchronous engine - engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + # Create a synchronous engine + engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") - # We add an event listener to the engine to enable synchronous Entra authentication - # for database access. This event listener is triggered whenever the connection pool - # backing the engine creates a new connection, ensuring that Entra authentication tokens - # are properly managed and refreshed so that each connection uses a valid token. - # - # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function - enable_entra_authentication(engine) + # We add an event listener to the engine to enable synchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication(engine) - with engine.connect() as conn: - # Query 1 - result = conn.execute(text("SELECT now()")) - row = result.fetchone() - print(f"Sync - Database time: {row[0] if row else 'Unknown'}") + with engine.connect() as conn: + # Query 1 + result = conn.execute(text("SELECT now()")) + row = result.fetchone() + print(f"Sync - Database time: {row[0] if row else 'Unknown'}") - # Query 2 - result = conn.execute(text("SELECT current_user")) - row = result.fetchone() - print(f"Sync - Connected as: {row[0] if row else 'Unknown'}") + # Query 2 + result = conn.execute(text("SELECT current_user")) + row = result.fetchone() + print(f"Sync - Connected as: {row[0] if row else 'Unknown'}") - # Clean up the engine - engine.dispose() - except Exception as e: - print(f"Sync - Error connecting to database: {e}") - raise + # Clean up the engine + engine.dispose() async def main_async() -> None: """Asynchronous connection example using SQLAlchemy with Entra ID authentication.""" - try: - # Create an asynchronous engine - engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") - - # We add an event listener to the engine to enable asynchronous Entra authentication - # for database access. This event listener is triggered whenever the connection pool - # backing the engine creates a new connection, ensuring that Entra authentication tokens - # are properly managed and refreshed so that each connection uses a valid token. - # - # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function - enable_entra_authentication_async(engine) - - # Core usage example - async with engine.connect() as conn: - # Query 1 - result = await conn.execute(text("SELECT now()")) - row = result.fetchone() - print(f"Async Core - Database time: {row[0] if row else 'Unknown'}") - - # Query 2 - result = await conn.execute(text("SELECT current_user")) - row = result.fetchone() - print(f"Async Core - Connected as: {row[0] if row else 'Unknown'}") - - # ORM usage example with async_sessionmaker - AsyncSession = async_sessionmaker(engine, expire_on_commit=False) - - async with AsyncSession() as session: - # Query 1 - result = await session.execute(text("SELECT current_database()")) - db_name = result.scalar() - print(f"Async ORM - Connected to database: {db_name}") - - # Query 2 - result = await session.execute(text("SELECT version()")) - version = result.scalar() - # Just show the first part of the version string for cleaner output - version_short = version.split(' on ')[0] if version else 'Unknown' - print(f"Async ORM - PostgreSQL version: {version_short}") - - # Clean up the engine - await engine.dispose() - except Exception as e: - print(f"Async - Error connecting to database: {e}") - raise + # Create an asynchronous engine + engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable asynchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication_async(engine) + + async with engine.connect() as conn: + # Query 1 + result = await conn.execute(text("SELECT now()")) + row = result.fetchone() + print(f"Async Core - Database time: {row[0] if row else 'Unknown'}") + + # Query 2 + result = await conn.execute(text("SELECT current_user")) + row = result.fetchone() + print(f"Async Core - Connected as: {row[0] if row else 'Unknown'}") + + # Clean up the engine + await engine.dispose() async def main(mode: str = "async") -> None: @@ -116,9 +91,9 @@ async def main(mode: str = "async") -> None: print("=== Running Synchronous SQLAlchemy Example ===") try: main_sync() - print("✅ Sync example completed successfully!") + print("Sync example completed successfully!") except Exception as e: - print(f"❌ Sync example failed: {e}") + print(f"Sync example failed: {e}") if mode in ("async", "both"): if mode == "both": @@ -127,9 +102,9 @@ async def main(mode: str = "async") -> None: print("=== Running Asynchronous SQLAlchemy Example ===") try: await main_async() - print("✅ Async example completed successfully!") + print("Async example completed successfully!") except Exception as e: - print(f"❌ Async example failed: {e}") + print(f"Async example failed: {e}") if __name__ == "__main__": diff --git a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py index f11bae3..2732326 100644 --- a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py +++ b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py @@ -33,21 +33,29 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - # Create synchronous engine with Entra authentication + # Create synchronous engine engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable synchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function enable_entra_authentication(engine) execution_count = 0 - try: - while True: - execution_count += 1 - current_time = datetime.now().strftime("%H:%M:%S") + # Get one connection and reuse it throughout the program + with engine.connect() as conn: + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") - print(f"Sync Execution #{execution_count} at {current_time}") + print(f"Sync Execution #{execution_count} at {current_time}") - try: - with engine.connect() as conn: + try: # Query 1: Get PostgreSQL version result = conn.execute(text("SELECT version()")) row = result.fetchone() @@ -68,13 +76,13 @@ def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: print("Sync query execution successful!") - except Exception as e: - print(f"Database error: {e}") + except Exception as e: + print(f"Database error: {e}") - print(f"Waiting {interval_minutes} minutes until next execution...\n") - time.sleep(interval_minutes * 60) - finally: - engine.dispose() + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + engine.dispose() async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: @@ -84,21 +92,29 @@ async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: print(f"Running queries every {interval_minutes} minutes...") print("Press Ctrl+C to stop\n") - # Create asynchronous engine with Entra authentication + # Create asynchronous engine engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable asynchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function enable_entra_authentication_async(engine) execution_count = 0 - try: - while True: - execution_count += 1 - current_time = datetime.now().strftime("%H:%M:%S") + # Get one connection and reuse it throughout the program + async with engine.connect() as conn: + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") - print(f"Async Execution #{execution_count} at {current_time}") + print(f"Async Execution #{execution_count} at {current_time}") - try: - async with engine.connect() as conn: + try: # Query 1: Get PostgreSQL version result = await conn.execute(text("SELECT version()")) row = result.fetchone() @@ -119,13 +135,13 @@ async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: print("Async query execution successful!") - except Exception as e: - print(f"Database error: {e}") + except Exception as e: + print(f"Database error: {e}") - print(f"Waiting {interval_minutes} minutes until next execution...\n") - await asyncio.sleep(interval_minutes * 60) - finally: - await engine.dispose() + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + await engine.dispose() async def main() -> None: diff --git a/python/src/azurepg_entra/__init__.py b/python/src/azurepg_entra/__init__.py index 5d3a282..a6d8259 100644 --- a/python/src/azurepg_entra/__init__.py +++ b/python/src/azurepg_entra/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. + """ Azure PostgreSQL Entra ID Integration Library diff --git a/python/src/azurepg_entra/core.py b/python/src/azurepg_entra/core.py index ee1844c..3eb2ad8 100644 --- a/python/src/azurepg_entra/core.py +++ b/python/src/azurepg_entra/core.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + import base64 import json from typing import Any, cast diff --git a/python/src/azurepg_entra/errors.py b/python/src/azurepg_entra/errors.py index 4478222..3782e64 100644 --- a/python/src/azurepg_entra/errors.py +++ b/python/src/azurepg_entra/errors.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft. All rights reserved. + class AzurePgEntraError(Exception): """Base class for all custom exceptions in the project.""" diff --git a/python/src/azurepg_entra/psycopg2/__init__.py b/python/src/azurepg_entra/psycopg2/__init__.py index 13339c1..9b223bf 100644 --- a/python/src/azurepg_entra/psycopg2/__init__.py +++ b/python/src/azurepg_entra/psycopg2/__init__.py @@ -1,30 +1,19 @@ # Copyright (c) Microsoft. All rights reserved. + """ Psycopg2 support for Azure Entra ID authentication with Azure Database for PostgreSQL. -This module provides connection classes that handle Azure Entra ID token acquisition -and authentication for synchronous (psycopg2) PostgreSQL connections. +This module provides a connection class that handles Azure Entra ID token acquisition +and authentication for synchronous PostgreSQL connections. Requirements: Install with: pip install azurepg-entra[psycopg2] This will install: - - psycopg2-binary>=2.8.0 + - psycopg2-binary>=2.9.0 Classes: - EntraConnection: Synchronous connection class with Entra ID authentication (psycopg2) - -Example usage: - # Synchronous connection - from azurepg_entra.psycopg2 import EntraConnection - - connection_pool = pool.ThreadedConnectionPool( - minconn=1, - maxconn=5, - host=SERVER, - database=DATABASE, - connection_factory=EntraConnection - ) + EntraConnection: Synchronous connection class with Entra ID authentication """ from .entra_connection import ( diff --git a/python/src/azurepg_entra/psycopg2/entra_connection.py b/python/src/azurepg_entra/psycopg2/entra_connection.py index d7f9fbd..5e6523c 100644 --- a/python/src/azurepg_entra/psycopg2/entra_connection.py +++ b/python/src/azurepg_entra/psycopg2/entra_connection.py @@ -1,7 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any +from typing import Any from azure.core.credentials import TokenCredential +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, +) try: from psycopg2.extensions import connection, make_dsn, parse_dsn @@ -12,13 +17,6 @@ "Install them with: pip install azurepg-entra[psycopg2]" ) from e -from azurepg_entra.core import get_entra_conninfo -from azurepg_entra.errors import ( - CredentialValueError, - EntraConnectionValueError, -) - - class EntraConnection(connection): """Establishes a synchronous PostgreSQL connection using Entra authentication. diff --git a/python/src/azurepg_entra/psycopg3/__init__.py b/python/src/azurepg_entra/psycopg3/__init__.py index 5d13eb9..86f8d8e 100644 --- a/python/src/azurepg_entra/psycopg3/__init__.py +++ b/python/src/azurepg_entra/psycopg3/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. + """ -Psycopg3 (psycopg) support for Azure Entra ID authentication with Azure Database for PostgreSQL. +Psycopg3 support for Azure Entra ID authentication with Azure Database for PostgreSQL. This module provides connection classes that extend psycopg's Connection and AsyncConnection to automatically handle Azure Entra ID token acquisition and authentication. @@ -10,26 +11,11 @@ This will install: - psycopg[binary]>=3.1.0 + - aiohttp>=3.8.0 Classes: EntraConnection: Synchronous connection class with Entra ID authentication AsyncEntraConnection: Asynchronous connection class with Entra ID authentication - -Example usage: - from azurepg_entra.psycopg3 import EntraConnection, AsyncEntraConnection - from psycopg_pool import ConnectionPool, AsyncConnectionPool - - # Synchronous usage - pool = ConnectionPool( - conninfo="postgresql://myserver:5432/mydb", - connection_class=EntraConnection - ) - - # Asynchronous usage - async_pool = AsyncConnectionPool( - conninfo="postgresql://myserver:5432/mydb", - connection_class=AsyncEntraConnection - ) """ from .async_entra_connection import AsyncEntraConnection diff --git a/python/src/azurepg_entra/psycopg3/async_entra_connection.py b/python/src/azurepg_entra/psycopg3/async_entra_connection.py index 5eaaefc..2e86443 100644 --- a/python/src/azurepg_entra/psycopg3/async_entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/async_entra_connection.py @@ -1,13 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any - from azure.core.credentials_async import AsyncTokenCredential try: from psycopg import AsyncConnection except ImportError as e: - # Provide a helpful error message if psycopg3 dependencies are missing raise ImportError( "psycopg3 dependencies are not installed. " "Install them with: pip install azurepg-entra[psycopg3]" @@ -19,13 +17,8 @@ EntraConnectionValueError, ) -try: - from typing import Self -except ImportError: - from typing_extensions import Self # fallback for older Python - -class AsyncEntraConnection(AsyncConnection[tuple[Any, ...]]): +class AsyncEntraConnection(AsyncConnection): """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" @classmethod @@ -60,7 +53,7 @@ async def connect(cls, *args: Any, **kwargs: Any) -> "AsyncEntraConnection": if not kwargs.get("user") or not kwargs.get("password"): try: entra_conninfo = await get_entra_conninfo_async(credential) - except (Exception) as e: + except Exception as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/src/azurepg_entra/psycopg3/entra_connection.py b/python/src/azurepg_entra/psycopg3/entra_connection.py index 3da72e6..2ca523f 100644 --- a/python/src/azurepg_entra/psycopg3/entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/entra_connection.py @@ -6,7 +6,6 @@ try: from psycopg import Connection except ImportError as e: - # Provide a helpful error message if psycopg3 dependencies are missing raise ImportError( "psycopg3 dependencies are not installed. " "Install them with: pip install azurepg-entra[psycopg3]" @@ -19,7 +18,7 @@ ) -class EntraConnection(Connection[tuple[Any, ...]]): +class EntraConnection(Connection): """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" @classmethod @@ -54,7 +53,7 @@ def connect(cls, *args: Any, **kwargs: Any) -> "EntraConnection": if not kwargs.get("user") or not kwargs.get("password"): try: entra_conninfo = get_entra_conninfo(credential) - except (Exception) as e: + except Exception as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/src/azurepg_entra/sqlalchemy/__init__.py b/python/src/azurepg_entra/sqlalchemy/__init__.py index 30aca3f..5767946 100644 --- a/python/src/azurepg_entra/sqlalchemy/__init__.py +++ b/python/src/azurepg_entra/sqlalchemy/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. + """ SQLAlchemy integration for Azure PostgreSQL with Entra ID authentication. @@ -6,24 +7,16 @@ authentication for PostgreSQL connections. It automatically handles token acquisition and credential injection through SQLAlchemy's event system. -Usage: - Synchronous engines: - from sqlalchemy import create_engine - from azurepg_entra.sqlalchemy import enable_entra_authentication - - engine = create_engine("postgresql://myserver.postgres.database.azure.com/mydb") - enable_entra_authentication(engine) - - Asynchronous engines: - from sqlalchemy.ext.asyncio import create_async_engine - from azurepg_entra.sqlalchemy import enable_entra_authentication_async +Requirements: + Install with: pip install azurepg-entra[sqlalchemy] - engine = create_async_engine("postgresql+asyncpg://myserver.postgres.database.azure.com/mydb") - enable_entra_authentication_async(engine) + This will install: + - sqlalchemy>=2.0.0 + - aiohttp>=3.8.0 Functions: - enable_entra_authentication: Enable Entra ID auth for synchronous SQLAlchemy engines - enable_entra_authentication_async: Enable Entra ID auth for asynchronous SQLAlchemy engines + enable_entra_authentication: Enable Entra ID authentication for synchronous SQLAlchemy engines + enable_entra_authentication_async: Enable Entra ID authentication for asynchronous SQLAlchemy engines """ from .async_entra_connection import enable_entra_authentication_async diff --git a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py index 20f63b8..cd56b3d 100644 --- a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py @@ -1,7 +1,7 @@ -from typing import Any +# Copyright (c) Microsoft. All rights reserved. +from typing import Any from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential try: from sqlalchemy import event @@ -35,21 +35,19 @@ def enable_entra_authentication_async(engine: AsyncEngine) -> None: """ @event.listens_for(engine.sync_engine, "do_connect") - def provide_token_async( + def provide_token( dialect: Dialect, conn_rec: Any, cargs: Any, cparams: dict[str, Any] ) -> None: - """Event handler that provides Entra credentials for each async connection. + """Event handler that provides Entra credentials for each sync connection. Raises: CredentialValueError: If the provided credential is not a valid TokenCredential. EntraConnectionValueError: If Entra connection credentials cannot be retrieved """ credential = cparams.get("credential", None) - if credential and not isinstance( - credential, (AsyncTokenCredential, TokenCredential) - ): + if credential and not isinstance(credential, (TokenCredential)): raise CredentialValueError( - "credential must be an AsyncTokenCredential or TokenCredential for async connections" + "credential must be a TokenCredential for async connections" ) # Check if credentials are already present has_user = "user" in cparams @@ -58,14 +56,8 @@ def provide_token_async( # Only get Entra credentials if user or password is missing if not has_user or not has_password: try: - # Cast to TokenCredential since SQLAlchemy events are synchronous - sync_credential: TokenCredential | None = ( - credential - if isinstance(credential, TokenCredential) or credential is None - else None - ) - entra_creds = get_entra_conninfo(sync_credential) - except (Exception) as e: + entra_creds = get_entra_conninfo(credential) + except Exception as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/src/azurepg_entra/sqlalchemy/entra_connection.py b/python/src/azurepg_entra/sqlalchemy/entra_connection.py index 972ab3f..2490260 100644 --- a/python/src/azurepg_entra/sqlalchemy/entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/entra_connection.py @@ -1,12 +1,12 @@ -from typing import Any +# Copyright (c) Microsoft. All rights reserved. +from typing import Any from azure.core.credentials import TokenCredential try: from sqlalchemy import Engine, event from sqlalchemy.engine import Dialect except ImportError as e: - # Provide a helpful error message if SQLAlchemy dependencies are missing raise ImportError( "SQLAlchemy dependencies are not installed. " "Install them with: pip install azurepg-entra[sqlalchemy]" @@ -53,7 +53,7 @@ def provide_token( if not has_user or not has_password: try: entra_creds = get_entra_conninfo(credential) - except (Exception) as e: + except Exception as e: raise EntraConnectionValueError( "Could not retrieve Entra credentials" ) from e diff --git a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py similarity index 73% rename from python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py rename to python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py index f7ed013..6c670cb 100644 --- a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py @@ -1,14 +1,25 @@ # Copyright (c) Microsoft. All rights reserved. + from unittest.mock import patch -import jwt +import json +import base64 import pytest from psycopg2.extensions import make_dsn, parse_dsn def create_test_token(payload): - """Helper to create a test JWT token.""" - return jwt.encode(payload, key="", algorithm="none") + """Helper to create a test JWT token manually.""" + # Create a simple JWT-like token with header.payload.signature format + header = {"alg": "none", "typ": "JWT"} + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) + signature = "" + return f"{header_encoded}.{payload_encoded}.{signature}" class TestEntraConnection: @@ -23,11 +34,9 @@ def test_dsn_processing_adds_entra_credentials(self): "password": token, } - from azurepg_entra.core import get_entra_conninfo - # Test with existing DSN parameters original_dsn = "host=localhost port=5432 dbname=testdb sslmode=require" - entra_creds = get_entra_conninfo(None) + entra_creds = mock_get_creds(None) dsn_params = parse_dsn(original_dsn) if original_dsn else {} dsn_params.update(entra_creds) diff --git a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg3/test_entra_id_extension.py similarity index 99% rename from python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py rename to python/tests/azure/data/postgresql/psycopg3/test_entra_id_extension.py index f730c84..c93214f 100644 --- a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg3/test_entra_id_extension.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. + from unittest.mock import AsyncMock, Mock, patch import pytest diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py similarity index 100% rename from python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py rename to python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py index 7590f0b..f841297 100644 --- a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch import pytest diff --git a/python/tests/azure/data/postgresql/test_core_functionality.py b/python/tests/azure/data/postgresql/test_core_functionality.py index bb2a554..5f46818 100644 --- a/python/tests/azure/data/postgresql/test_core_functionality.py +++ b/python/tests/azure/data/postgresql/test_core_functionality.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. + import base64 import json from unittest.mock import AsyncMock, Mock, patch From 4b2f09d18fb82ff6ed45ba73a1447c1eab8e5f97 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Mon, 13 Oct 2025 21:09:10 -0700 Subject: [PATCH 14/19] Fix formatting and workflow script --- .../psycopg2/entra_connection.py | 2 ++ .../psycopg3/async_entra_connection.py | 1 + .../psycopg3/entra_connection.py | 1 + .../sqlalchemy/async_entra_connection.py | 1 + .../sqlalchemy/entra_connection.py | 1 + .../psycopg2/test_entra_id_extension.py | 4 ++-- .../sqlalchemy/test_entra_id_extension.py | 1 + run-python-checks.ps1 | 24 +++++++++++++++++-- 8 files changed, 31 insertions(+), 4 deletions(-) diff --git a/python/src/azurepg_entra/psycopg2/entra_connection.py b/python/src/azurepg_entra/psycopg2/entra_connection.py index 5e6523c..cfb21d4 100644 --- a/python/src/azurepg_entra/psycopg2/entra_connection.py +++ b/python/src/azurepg_entra/psycopg2/entra_connection.py @@ -1,7 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any + from azure.core.credentials import TokenCredential + from azurepg_entra.core import get_entra_conninfo from azurepg_entra.errors import ( CredentialValueError, diff --git a/python/src/azurepg_entra/psycopg3/async_entra_connection.py b/python/src/azurepg_entra/psycopg3/async_entra_connection.py index 2e86443..f694020 100644 --- a/python/src/azurepg_entra/psycopg3/async_entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/async_entra_connection.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any + from azure.core.credentials_async import AsyncTokenCredential try: diff --git a/python/src/azurepg_entra/psycopg3/entra_connection.py b/python/src/azurepg_entra/psycopg3/entra_connection.py index 2ca523f..9197736 100644 --- a/python/src/azurepg_entra/psycopg3/entra_connection.py +++ b/python/src/azurepg_entra/psycopg3/entra_connection.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any + from azure.core.credentials import TokenCredential try: diff --git a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py index cd56b3d..fc4848c 100644 --- a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any + from azure.core.credentials import TokenCredential try: diff --git a/python/src/azurepg_entra/sqlalchemy/entra_connection.py b/python/src/azurepg_entra/sqlalchemy/entra_connection.py index 2490260..f4cf80c 100644 --- a/python/src/azurepg_entra/sqlalchemy/entra_connection.py +++ b/python/src/azurepg_entra/sqlalchemy/entra_connection.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from typing import Any + from azure.core.credentials import TokenCredential try: diff --git a/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py index 6c670cb..4a51038 100644 --- a/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. +import base64 +import json from unittest.mock import patch -import json -import base64 import pytest from psycopg2.extensions import make_dsn, parse_dsn diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py index f841297..f078704 100644 --- a/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py +++ b/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from unittest.mock import Mock, patch + import pytest diff --git a/run-python-checks.ps1 b/run-python-checks.ps1 index 1f443c9..56c36da 100644 --- a/run-python-checks.ps1 +++ b/run-python-checks.ps1 @@ -101,8 +101,28 @@ try { if (Test-Path "tests") { Write-Host "Running pytest" -ForegroundColor Blue - if ($Verbose) { & $venvPython -m pytest tests -v } else { & $venvPython -m pytest tests -q *> $null } - Write-CheckResult "pytest" ($LASTEXITCODE -eq 0) + + # Run tests for each subdirectory separately to avoid import collisions + $testDirs = @( + "tests/azure/data/postgresql/psycopg2", + "tests/azure/data/postgresql/psycopg3", + "tests/azure/data/postgresql/sqlalchemy", + "tests/azure/data/postgresql/test_core_functionality.py" + ) + + $allTestsPass = $true + foreach ($testDir in $testDirs) { + if (Test-Path $testDir) { + Write-Host " Testing $testDir" -ForegroundColor Gray + if ($Verbose) { + & $venvPython -m pytest $testDir -v + } else { + & $venvPython -m pytest $testDir -q *> $null + } + if ($LASTEXITCODE -ne 0) { $allTestsPass = $false } + } + } + Write-CheckResult "pytest" $allTestsPass } else { Write-Host "WARN No tests directory present" -ForegroundColor Yellow } From 597504a263a2bfcd77821b7bafb492daf44ef564 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Mon, 13 Oct 2025 21:19:37 -0700 Subject: [PATCH 15/19] Fix Python workflow file --- run-python-checks.ps1 | 30 +++++++----------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/run-python-checks.ps1 b/run-python-checks.ps1 index 56c36da..4262bb6 100644 --- a/run-python-checks.ps1 +++ b/run-python-checks.ps1 @@ -91,10 +91,6 @@ try { if ($Verbose) { & $venvPython -m ruff check ./src ./tests } else { & $venvPython -m ruff check ./src ./tests *> $null } Write-CheckResult "ruff lint" ($LASTEXITCODE -eq 0) - # mypy target - if ($Verbose) { & $venvPython -m mypy ./src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py } else { & $venvPython -m mypy ./src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py *> $null } - Write-CheckResult "mypy (target)" ($LASTEXITCODE -eq 0) - # mypy all if ($Verbose) { & $venvPython -m mypy ./src/azurepg_entra/ } else { & $venvPython -m mypy ./src/azurepg_entra/ *> $null } Write-CheckResult "mypy (all)" ($LASTEXITCODE -eq 0) @@ -102,26 +98,14 @@ try { if (Test-Path "tests") { Write-Host "Running pytest" -ForegroundColor Blue - # Run tests for each subdirectory separately to avoid import collisions - $testDirs = @( - "tests/azure/data/postgresql/psycopg2", - "tests/azure/data/postgresql/psycopg3", - "tests/azure/data/postgresql/sqlalchemy", - "tests/azure/data/postgresql/test_core_functionality.py" - ) - - $allTestsPass = $true - foreach ($testDir in $testDirs) { - if (Test-Path $testDir) { - Write-Host " Testing $testDir" -ForegroundColor Gray - if ($Verbose) { - & $venvPython -m pytest $testDir -v - } else { - & $venvPython -m pytest $testDir -q *> $null - } - if ($LASTEXITCODE -ne 0) { $allTestsPass = $false } - } + # Use importlib mode to avoid import collisions from files with same basename + # This allows pytest to handle multiple test_entra_id_extension.py files + if ($Verbose) { + & $venvPython -m pytest tests --import-mode=importlib -v + } else { + & $venvPython -m pytest tests --import-mode=importlib -q *> $null } + $allTestsPass = ($LASTEXITCODE -eq 0) Write-CheckResult "pytest" $allTestsPass } else { Write-Host "WARN No tests directory present" -ForegroundColor Yellow From 093ae17434533ce0fdb2f5f555d1b850eb755361 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Mon, 13 Oct 2025 21:24:36 -0700 Subject: [PATCH 16/19] Update Python workflow script --- .github/workflows/pr-python.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-python.yml b/.github/workflows/pr-python.yml index d91fd43..b699bc2 100644 --- a/.github/workflows/pr-python.yml +++ b/.github/workflows/pr-python.yml @@ -24,8 +24,8 @@ jobs: run: python -m ruff check src tests - name: Type check working-directory: python - run: python -m mypy src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py + run: python -m mypy src/azurepg_entra/ - name: Tests if: ${{ always() }} # adjust if you add tests working-directory: python - run: python -m pytest -q \ No newline at end of file + run: python -m pytest tests --import-mode=importlib -q \ No newline at end of file From 6db50b5cd98b4aa88e3501e20cee05b688d6fdea Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Mon, 13 Oct 2025 21:39:04 -0700 Subject: [PATCH 17/19] Fix file naming in solution file --- dotnet/dotnet.sln | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/dotnet.sln b/dotnet/dotnet.sln index 36674be..8c103bd 100644 --- a/dotnet/dotnet.sln +++ b/dotnet/dotnet.sln @@ -22,7 +22,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Postgresql", "Postgresql", EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Npgsql", "Npgsql", "{93934517-16C9-C51A-8F2B-54760F50BDEB}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql", "src\Azure\Data\Postgresql\Npgsql\Azure.Data.Postgresql.Npgsql.csproj", "{3E862DB4-B843-4361-94B5-8CF34402B511}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql", "src\Azure\Data\PostgreSql\Npgsql\Azure.Data.Postgresql.Npgsql.csproj", "{3E862DB4-B843-4361-94B5-8CF34402B511}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{0AB3BF05-4346-4AA6-1389-037BE0695223}" EndProject @@ -34,7 +34,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Postgresql", "Postgresql", EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Npgsql", "Npgsql", "{290860F1-0C73-540D-3A79-AA6C3ABBD9C3}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql.Tests", "tests\Azure\Data\Postgresql\Npgsql\Azure.Data.Postgresql.Npgsql.Tests.csproj", "{750B2A4F-9EF5-4CC5-8EF9-A93F4A1748F6}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql.Tests", "tests\Azure\Data\PostgreSql\Npgsql\Azure.Data.Postgresql.Npgsql.Tests.csproj", "{750B2A4F-9EF5-4CC5-8EF9-A93F4A1748F6}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution From c7cd50ff98a3abe6139ab13e2e3d8416c77aa776 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Mon, 13 Oct 2025 21:43:36 -0700 Subject: [PATCH 18/19] Fix dotnet script --- dotnet/dotnet.sln | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/dotnet.sln b/dotnet/dotnet.sln index 8c103bd..9205c04 100644 --- a/dotnet/dotnet.sln +++ b/dotnet/dotnet.sln @@ -34,7 +34,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Postgresql", "Postgresql", EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Npgsql", "Npgsql", "{290860F1-0C73-540D-3A79-AA6C3ABBD9C3}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql.Tests", "tests\Azure\Data\PostgreSql\Npgsql\Azure.Data.Postgresql.Npgsql.Tests.csproj", "{750B2A4F-9EF5-4CC5-8EF9-A93F4A1748F6}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql.Tests", "tests\Azure\Data\Postgresql\Npgsql\Azure.Data.Postgresql.Npgsql.Tests.csproj", "{750B2A4F-9EF5-4CC5-8EF9-A93F4A1748F6}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution From a62a171050e8e14b2f23c9080f3187977081bd75 Mon Sep 17 00:00:00 2001 From: Arjun Narendra Date: Tue, 14 Oct 2025 09:58:07 -0700 Subject: [PATCH 19/19] Update scripting to only focus on Python and not .NET --- .github/workflows/pr-dotnet.yml | 25 ------ run-dotnet-checks.ps1 | 78 ------------------- .../run-python-checks.ps1 | 21 ++--- 3 files changed, 8 insertions(+), 116 deletions(-) delete mode 100644 .github/workflows/pr-dotnet.yml delete mode 100644 run-dotnet-checks.ps1 rename run-python-checks.ps1 => scripts/run-python-checks.ps1 (80%) diff --git a/.github/workflows/pr-dotnet.yml b/.github/workflows/pr-dotnet.yml deleted file mode 100644 index baa703c..0000000 --- a/.github/workflows/pr-dotnet.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: PR .NET Checks -on: - pull_request: - branches: [ main ] - paths: - - 'dotnet/**' - - '.github/workflows/pr-dotnet.yml' - -jobs: - dotnet-quality: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '9.0.x' - - name: Restore - working-directory: dotnet - run: dotnet restore - - name: Build - working-directory: dotnet - run: dotnet build --configuration Release --no-restore - - name: Test - working-directory: dotnet - run: dotnet test --configuration Release --no-build --verbosity normal \ No newline at end of file diff --git a/run-dotnet-checks.ps1 b/run-dotnet-checks.ps1 deleted file mode 100644 index dda71f8..0000000 --- a/run-dotnet-checks.ps1 +++ /dev/null @@ -1,78 +0,0 @@ -param( - [string]$Configuration = "Release", - [switch]$Verbose, - [switch]$Help -) - -<# -.SYNOPSIS - Run .NET quality checks (restore, build, test) locally. - -.DESCRIPTION - Mirrors the CI steps in pr-dotnet.yml for confidence before pushing. - -.PARAMETER Configuration - Build configuration (Release or Debug). Defaults to Release to match CI. - -.PARAMETER Verbose - Show full command output (otherwise minimal). - -.EXAMPLE - ./run-dotnet-checks.ps1 - -.EXAMPLE - ./run-dotnet-checks.ps1 -Configuration Debug -#> - -if ($Help) { - Get-Help -Detailed -ErrorAction SilentlyContinue - Write-Host "Usage: ./run-dotnet-checks.ps1 [-Configuration Release|Debug] [-Verbose]" -ForegroundColor Cyan - exit 0 -} - -function Write-CheckResult { - param([string]$Name, [bool]$Success, [string]$Message = "") - if ($Success) { - Write-Host "PASS $Name" -ForegroundColor Green - if ($Message) { Write-Host " $Message" -ForegroundColor Gray } - } else { - Write-Host "FAIL $Name" -ForegroundColor Red - if ($Message) { Write-Host " $Message" -ForegroundColor Gray } - $script:OverallSuccess = $false - } -} - -$OverallSuccess = $true -$dotnetRoot = Join-Path (Get-Location) "dotnet" -if (-not (Test-Path $dotnetRoot)) { - Write-Host "dotnet/ directory not found" -ForegroundColor Red - exit 1 -} - -Push-Location $dotnetRoot -try { - if (-not (Get-Command dotnet -ErrorAction SilentlyContinue)) { - Write-Host ".NET SDK not found in PATH" -ForegroundColor Red - exit 1 - } - - $sdkVersion = dotnet --version - Write-Host "Using .NET SDK $sdkVersion (Configuration=$Configuration)" - - Write-Host "Restoring packages" -ForegroundColor Blue - if ($Verbose) { dotnet restore } else { dotnet restore --verbosity minimal } - Write-CheckResult "restore" ($LASTEXITCODE -eq 0) - - Write-Host "Building solution" -ForegroundColor Blue - if ($Verbose) { dotnet build --no-restore --configuration $Configuration } else { dotnet build --no-restore --configuration $Configuration --verbosity minimal } - Write-CheckResult "build" ($LASTEXITCODE -eq 0) - - Write-Host "Running tests" -ForegroundColor Blue - if ($Verbose) { dotnet test --no-build --configuration $Configuration --verbosity normal } else { dotnet test --no-build --configuration $Configuration --verbosity minimal } - Write-CheckResult "test" ($LASTEXITCODE -eq 0) -} -finally { - Pop-Location -} - -if (-not $OverallSuccess) { exit 1 } diff --git a/run-python-checks.ps1 b/scripts/run-python-checks.ps1 similarity index 80% rename from run-python-checks.ps1 rename to scripts/run-python-checks.ps1 index 4262bb6..603b33a 100644 --- a/run-python-checks.ps1 +++ b/scripts/run-python-checks.ps1 @@ -9,7 +9,7 @@ param( Run Python quality checks (lint, type, tests) locally. .DESCRIPTION - Mirrors the CI steps defined in pr-python.yml: install deps (.[all]), Ruff lint, mypy (target + package), pytest, and import validation. + Mirrors the CI steps defined in pr-python.yml: install deps (.[all]), Ruff lint, mypy, and pytest. .PARAMETER Verbose Show full tool output instead of suppressing it. @@ -78,22 +78,20 @@ try { & $venvPython -m pip install --upgrade pip | Out-Null Write-CheckResult "pip upgrade" ($LASTEXITCODE -eq 0) - Write-Host "Installing project deps (editable .[all])" -ForegroundColor Blue - & $venvPython -m pip install -e .[all] | Out-Null + Write-Host "Installing project deps (.[all])" -ForegroundColor Blue + & $venvPython -m pip install .[all] | Out-Null $depsOk = $LASTEXITCODE -eq 0 if (-not $depsOk) { Write-CheckResult "Install deps" $false; exit 1 } else { Write-CheckResult "Install deps" $true } - & $venvPython -m pip install types-psycopg2 aiohttp | Out-Null - Write-CheckResult "Extra deps (types-psycopg2,aiohttp)" ($LASTEXITCODE -eq 0) - # Ruff Write-Host "Running Ruff lint" -ForegroundColor Blue - if ($Verbose) { & $venvPython -m ruff check ./src ./tests } else { & $venvPython -m ruff check ./src ./tests *> $null } + if ($Verbose) { & $venvPython -m ruff check src tests } else { & $venvPython -m ruff check src tests *> $null } Write-CheckResult "ruff lint" ($LASTEXITCODE -eq 0) - # mypy all - if ($Verbose) { & $venvPython -m mypy ./src/azurepg_entra/ } else { & $venvPython -m mypy ./src/azurepg_entra/ *> $null } - Write-CheckResult "mypy (all)" ($LASTEXITCODE -eq 0) + # mypy + Write-Host "Running mypy type check" -ForegroundColor Blue + if ($Verbose) { & $venvPython -m mypy src/azurepg_entra/ } else { & $venvPython -m mypy src/azurepg_entra/ *> $null } + Write-CheckResult "mypy" ($LASTEXITCODE -eq 0) if (Test-Path "tests") { Write-Host "Running pytest" -ForegroundColor Blue @@ -110,9 +108,6 @@ try { } else { Write-Host "WARN No tests directory present" -ForegroundColor Yellow } - - & $venvPython -c "import sys; sys.path.insert(0, 'src'); import azurepg_entra, azurepg_entra.core" 2>$null - Write-CheckResult "Import validation" ($LASTEXITCODE -eq 0) } finally { Pop-Location