diff --git a/.gitignore b/.gitignore index f4cda9a6..c5628ff2 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ build/ dist/ experiment_data/ +*.egg-info/ ## Ignore Visual Studio temporary files, build results, and ## files generated by popular Visual Studio add-ons. diff --git a/py-src/data_formulator/data_loader/README.md b/py-src/data_formulator/data_loader/README.md index 70079ec9..2c059905 100644 --- a/py-src/data_formulator/data_loader/README.md +++ b/py-src/data_formulator/data_loader/README.md @@ -35,6 +35,8 @@ The UI uses the same loaders for connection setup, table listing, and ingestion - **`MySQLDataLoader`**: MySQL (connectorx) - **`PostgreSQLDataLoader`**: PostgreSQL (connectorx) - **`MSSQLDataLoader`**: Microsoft SQL Server (connectorx) +- **`AzureSQLDataLoader`**: Azure SQL Database / Managed Instance — SQL auth (connectorx) or Entra ID (pyodbc + AAD token) +- **`FabricLakehouseDataLoader`**: Microsoft Fabric Lakehouse / Data Warehouse SQL Analytics Endpoint — Entra ID only (pyodbc + AAD token) - **`S3DataLoader`**: Amazon S3 files (CSV, Parquet, JSON) via PyArrow S3 filesystem - **`AzureBlobDataLoader`**: Azure Blob Storage via PyArrow - **`MongoDBDataLoader`**: MongoDB diff --git a/py-src/data_formulator/data_loader/__init__.py b/py-src/data_formulator/data_loader/__init__.py index 898c50f5..e12ecef1 100644 --- a/py-src/data_formulator/data_loader/__init__.py +++ b/py-src/data_formulator/data_loader/__init__.py @@ -8,6 +8,8 @@ from data_formulator.data_loader.mongodb_data_loader import MongoDBDataLoader from data_formulator.data_loader.bigquery_data_loader import BigQueryDataLoader from data_formulator.data_loader.athena_data_loader import AthenaDataLoader +from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader +from data_formulator.data_loader.fabric_lakehouse_data_loader import FabricLakehouseDataLoader DATA_LOADERS = { "mysql": MySQLDataLoader, @@ -18,7 +20,9 @@ "postgresql": PostgreSQLDataLoader, "mongodb": MongoDBDataLoader, "bigquery": BigQueryDataLoader, - "athena": AthenaDataLoader + "athena": AthenaDataLoader, + "azure_sql": AzureSQLDataLoader, + "fabric_lakehouse": FabricLakehouseDataLoader, } __all__ = [ @@ -31,5 +35,7 @@ "PostgreSQLDataLoader", "MongoDBDataLoader", "BigQueryDataLoader", - "AthenaDataLoader", + "AthenaDataLoader", + "AzureSQLDataLoader", + "FabricLakehouseDataLoader", "DATA_LOADERS"] \ No newline at end of file diff --git a/py-src/data_formulator/data_loader/azure_sql_data_loader.py b/py-src/data_formulator/data_loader/azure_sql_data_loader.py new file mode 100644 index 00000000..aa9010ca --- /dev/null +++ b/py-src/data_formulator/data_loader/azure_sql_data_loader.py @@ -0,0 +1,431 @@ +import json +import logging +import struct +from typing import Any + +import pandas as pd +import pyarrow as pa + +from data_formulator.data_loader.external_data_loader import ExternalDataLoader + +log = logging.getLogger(__name__) + +# Azure SQL resource for token-based authentication +_AZURE_SQL_RESOURCE = "https://database.windows.net//.default" + + +def _esc(value: str) -> str: + """Escape a string value for safe embedding in a T-SQL string literal. + + Replaces each single-quote with two single-quotes, which is the standard + T-SQL method for escaping string literals. Values are sourced from + INFORMATION_SCHEMA (database object names), so this is an extra safety + measure against edge-case object names. 
+    """
+    return value.replace("'", "''")
+
+# SQL Server types that connectorx cannot handle natively
+_CX_SPATIAL_TYPES = {"geometry", "geography"}
+_CX_OTHER_UNSUPPORTED = {"hierarchyid", "xml", "sql_variant", "image", "timestamp"}
+_CX_UNSUPPORTED_TYPES = _CX_SPATIAL_TYPES | _CX_OTHER_UNSUPPORTED
+
+# Attribute ID for injecting an AAD access token into an ODBC connection
+_SQL_COPT_SS_ACCESS_TOKEN = 1256
+
+
+def _token_bytes(access_token: str) -> bytes:
+    """Encode an AAD access token as a length-prefixed UTF-16-LE byte string
+    expected by the SQL Server ODBC driver (SQL_COPT_SS_ACCESS_TOKEN)."""
+    encoded = access_token.encode("UTF-16-LE")
+    return struct.pack(f"<I{len(encoded)}s", len(encoded), encoded)
+
+
+class AzureSQLDataLoader(ExternalDataLoader):
+
+    @staticmethod
+    def list_params() -> list[dict[str, Any]]:
+        return [
+            {
+                "name": "server",
+                "type": "string",
+                "required": True,
+                "default": "",
+                "description": "Azure SQL server hostname (e.g. myserver.database.windows.net)",
+            },
+            {
+                "name": "database",
+                "type": "string",
+                "required": True,
+                "default": "",
+                "description": "Database name to connect to",
+            },
+            {
+                "name": "user",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "SQL username (leave empty to use Entra ID authentication)",
+            },
+            {
+                "name": "password",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "SQL password (leave empty to use Entra ID authentication)",
+            },
+            {
+                "name": "port",
+                "type": "string",
+                "required": False,
+                "default": "1433",
+                "description": "TCP port (default: 1433)",
+            },
+            {
+                "name": "client_id",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "Entra ID application (client) ID for service principal auth",
+            },
+            {
+                "name": "client_secret",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "Entra ID client secret for service principal auth",
+            },
+            {
+                "name": "tenant_id",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "Entra ID tenant ID for service principal auth",
+            },
+            {
+                "name": "driver",
+                "type": "string",
+                "required": False,
+                "default": "ODBC Driver 18 for SQL Server",
+                "description": "ODBC driver name used for Entra ID authentication",
+            },
+        ]
+
+    @staticmethod
+    def auth_instructions() -> str:
+        return """**Example (SQL auth):** server: `myserver.database.windows.net` · database: `mydb` · user: `myuser` · password: `MyP@ss`
+
+**Example (Service Principal):** server: `myserver.database.windows.net` · database: `mydb` · client_id: `abc-123...` · client_secret: `xyz...` · tenant_id: `def-456...`
+
+**Example (Azure CLI):** server: `myserver.database.windows.net` · database: `mydb` (run `az login` first, leave all credential fields empty)
+
+**Authentication Options:**
+- **SQL Auth:** Provide `user` and `password`. Uses connectorx for high-speed Arrow-native reads.
+- **Service Principal (Entra ID):** Provide `client_id`, `client_secret`, and `tenant_id`. The service principal must have at least the `db_datareader` role on the target database.
+- **Azure CLI / DefaultAzureCredential:** Run `az login` in your terminal. Leave user, password, client_id, client_secret, and tenant_id empty.
+
+**Prerequisites (Entra ID auth):** ODBC Driver 17 or 18 for SQL Server must be installed. 
See [Microsoft Docs](https://learn.microsoft.com/sql/connect/odbc/download-odbc-driver-for-sql-server).""" + + def __init__(self, params: dict[str, Any]): + log.info(f"Initializing AzureSQL DataLoader with parameters: {params}") + self.params = params + + self.server = params.get("server", "").strip() + self.database = params.get("database", "").strip() + self.user = params.get("user", "").strip() + self.password = params.get("password", "").strip() + self.port = params.get("port", "1433").strip() or "1433" + self.client_id = params.get("client_id", "").strip() + self.client_secret = params.get("client_secret", "").strip() + self.tenant_id = params.get("tenant_id", "").strip() + self.driver = ( + params.get("driver", "ODBC Driver 18 for SQL Server").strip() + or "ODBC Driver 18 for SQL Server" + ) + + if not self.server: + raise ValueError("Azure SQL server hostname is required") + if not self.database: + raise ValueError("Database name is required") + + if self.user and self.password: + self._auth_mode = "sql" + server_for_url = ( + "127.0.0.1" + if self.server.lower() == "localhost" + else self.server + ) + self._connection_url = ( + f"mssql://{self.user}:{self.password}" + f"@{server_for_url}:{self.port}/{self.database}" + "?TrustServerCertificate=true" + ) + try: + import connectorx as cx + + cx.read_sql(self._connection_url, "SELECT 1", return_type="arrow") + log.info( + f"Connected to Azure SQL (SQL auth): {self.server}/{self.database}" + ) + except Exception as e: + raise ValueError( + f"Failed to connect to Azure SQL '{self.server}': {e}" + ) from e + else: + self._auth_mode = "entra" + try: + conn = self._get_pyodbc_connection() + conn.close() + log.info( + f"Connected to Azure SQL (Entra auth): {self.server}/{self.database}" + ) + except Exception as e: + raise ValueError( + f"Failed to connect to Azure SQL '{self.server}' with Entra auth: {e}" + ) from e + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_access_token(self) -> str: + """Obtain an AAD access token for Azure SQL Database.""" + from azure.identity import ClientSecretCredential, DefaultAzureCredential + + if self.client_id and self.client_secret and self.tenant_id: + credential = ClientSecretCredential( + tenant_id=self.tenant_id, + client_id=self.client_id, + client_secret=self.client_secret, + ) + else: + credential = DefaultAzureCredential() + + return credential.get_token(_AZURE_SQL_RESOURCE).token + + def _get_pyodbc_connection(self): + """Create a pyodbc connection using token-based Entra authentication.""" + import pyodbc + + conn_str = ( + f"DRIVER={{{self.driver}}};" + f"SERVER={self.server},{self.port};" + f"DATABASE={self.database};" + "Encrypt=yes;" + "TrustServerCertificate=no;" + ) + token = _token_bytes(self._get_access_token()) + return pyodbc.connect(conn_str, attrs_before={_SQL_COPT_SS_ACCESS_TOKEN: token}) + + def _execute_query_cx(self, query: str) -> pa.Table: + """Execute a query via connectorx (SQL auth path).""" + import connectorx as cx + + return cx.read_sql(self._connection_url, query, return_type="arrow") + + def _execute_query_pyodbc(self, query: str) -> pa.Table: + """Execute a query via pyodbc (Entra auth path) and return an Arrow table.""" + conn = self._get_pyodbc_connection() + try: + df = pd.read_sql(query, conn) + return pa.Table.from_pandas(df, preserve_index=False) + finally: + conn.close() + + def _execute_query(self, query: str) -> pa.Table: + 
"""Dispatch a query to the appropriate execution backend.""" + if self._auth_mode == "sql": + return self._execute_query_cx(query) + return self._execute_query_pyodbc(query) + + def _safe_select_list(self, schema: str, table_name: str) -> str: + """Build a SELECT column list that casts unsupported types to text. + + Returns ``'*'`` when no unsupported columns are present. + """ + try: + cols_df = self._execute_query( + f""" + SELECT COLUMN_NAME, DATA_TYPE + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = '{_esc(schema)}' AND TABLE_NAME = '{_esc(table_name)}' + ORDER BY ORDINAL_POSITION + """ + ).to_pandas() + if not any( + r["DATA_TYPE"].lower() in _CX_UNSUPPORTED_TYPES + for _, r in cols_df.iterrows() + ): + return "*" + parts = [] + for _, r in cols_df.iterrows(): + col, dtype = r["COLUMN_NAME"], r["DATA_TYPE"].lower() + if dtype in _CX_SPATIAL_TYPES: + parts.append(f"[{col}].STAsText() AS [{col}]") + elif dtype in _CX_OTHER_UNSUPPORTED: + parts.append(f"CAST([{col}] AS NVARCHAR(MAX)) AS [{col}]") + else: + parts.append(f"[{col}]") + return ", ".join(parts) + except Exception: + return "*" + + # ------------------------------------------------------------------ + # ExternalDataLoader interface + # ------------------------------------------------------------------ + + def fetch_data_as_arrow( + self, + source_table: str, + size: int = 1000000, + sort_columns: list[str] | None = None, + sort_order: str = "asc", + ) -> pa.Table: + """Fetch data from Azure SQL as a PyArrow Table.""" + if not source_table: + raise ValueError("source_table must be provided") + + if "." in source_table: + schema, table = source_table.split(".", 1) + else: + schema, table = "dbo", source_table + + schema = schema.strip("[]") + table = table.strip("[]") + + col_list = self._safe_select_list(schema, table) + base_query = f"SELECT {col_list} FROM [{schema}].[{table}]" + + order_by_clause = "" + if sort_columns: + direction = "DESC" if sort_order == "desc" else "ASC" + order_by_clause = ( + " ORDER BY " + + ", ".join(f"[{col}] {direction}" for col in sort_columns) + ) + + query = f"SELECT TOP {size} * FROM ({base_query}{order_by_clause}) AS _limited" + + log.info(f"Executing Azure SQL query: {query[:200]}...") + arrow_table = self._execute_query(query) + log.info(f"Fetched {arrow_table.num_rows} rows from Azure SQL") + return arrow_table + + def list_tables(self, table_filter: str | None = None) -> list[dict[str, Any]]: + """List all tables from the Azure SQL database.""" + try: + tables_df = self._execute_query( + """ + SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND TABLE_SCHEMA NOT IN ('sys', 'INFORMATION_SCHEMA') + ORDER BY TABLE_SCHEMA, TABLE_NAME + """ + ).to_pandas() + except Exception as e: + log.error(f"Failed to list tables from Azure SQL: {e}") + return [] + + results = [] + for _, row in tables_df.iterrows(): + schema = row["TABLE_SCHEMA"] + table_name = row["TABLE_NAME"] + table_type = row.get("TABLE_TYPE", "BASE TABLE") + full_name = f"{schema}.{table_name}" + + if table_filter and table_filter.lower() not in full_name.lower(): + continue + + try: + columns_df = self._execute_query( + f""" + SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, + CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = '{_esc(schema)}' AND TABLE_NAME = '{_esc(table_name)}' + ORDER BY ORDINAL_POSITION + """ + ).to_pandas() + + columns = [] + for _, col_row in columns_df.iterrows(): 
+ col_info: dict[str, Any] = { + "name": col_row["COLUMN_NAME"], + "type": col_row["DATA_TYPE"], + "nullable": col_row["IS_NULLABLE"] == "YES", + "default": col_row["COLUMN_DEFAULT"], + } + for field, key in [ + ("CHARACTER_MAXIMUM_LENGTH", "max_length"), + ("NUMERIC_PRECISION", "precision"), + ("NUMERIC_SCALE", "scale"), + ]: + val = col_row[field] + if val is not None and not pd.isna(val): + try: + col_info[key] = int(val) + except (ValueError, TypeError): + pass + columns.append(col_info) + + col_list = self._safe_select_list(schema, table_name) + + sample_rows: list = [] + try: + sample_df = self._execute_query( + f"SELECT TOP 10 {col_list} FROM [{schema}].[{table_name}]" + ).to_pandas() + sample_rows = json.loads( + sample_df.fillna(value=None).to_json( + orient="records", date_format="iso", default_handler=str + ) + ) + except Exception as e: + log.warning(f"Failed to sample table {full_name}: {e}") + + count_df = self._execute_query( + f"SELECT COUNT(*) AS row_count FROM [{schema}].[{table_name}]" + ).to_pandas() + raw_count = count_df.iloc[0]["row_count"] + try: + row_count = 0 if pd.isna(raw_count) else int(raw_count) + except (ValueError, TypeError): + row_count = 0 + + results.append( + { + "name": full_name, + "metadata": { + "row_count": row_count, + "columns": columns, + "sample_rows": sample_rows, + "table_type": table_type, + }, + } + ) + except Exception as e: + log.warning(f"Failed to get metadata for table {full_name}: {e}") + results.append( + { + "name": full_name, + "metadata": { + "row_count": 0, + "columns": [], + "sample_rows": [], + "table_type": table_type, + }, + } + ) + + return results diff --git a/py-src/data_formulator/data_loader/external_data_loader.py b/py-src/data_formulator/data_loader/external_data_loader.py index 420cedd7..a28d295b 100644 --- a/py-src/data_formulator/data_loader/external_data_loader.py +++ b/py-src/data_formulator/data_loader/external_data_loader.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) # Sensitive parameter names that should be excluded from stored metadata -SENSITIVE_PARAMS = {'password', 'api_key', 'secret', 'token', 'access_key', 'secret_key'} +SENSITIVE_PARAMS = {'password', 'api_key', 'secret', 'token', 'access_key', 'secret_key', 'client_secret'} def sanitize_table_name(name_as: str) -> str: diff --git a/py-src/data_formulator/data_loader/fabric_lakehouse_data_loader.py b/py-src/data_formulator/data_loader/fabric_lakehouse_data_loader.py new file mode 100644 index 00000000..1fefac42 --- /dev/null +++ b/py-src/data_formulator/data_loader/fabric_lakehouse_data_loader.py @@ -0,0 +1,378 @@ +import json +import logging +import struct +from typing import Any + +import pandas as pd +import pyarrow as pa + +from data_formulator.data_loader.external_data_loader import ExternalDataLoader + +log = logging.getLogger(__name__) + +# Token resource for Microsoft Fabric SQL Analytics Endpoints +_FABRIC_SQL_RESOURCE = "https://analysis.windows.net/powerbi/api/.default" + +# Attribute ID for injecting an AAD access token into an ODBC connection +_SQL_COPT_SS_ACCESS_TOKEN = 1256 + + +def _esc(value: str) -> str: + """Escape a string value for safe embedding in a T-SQL string literal. + + Replaces each single-quote with two single-quotes, which is the standard + T-SQL method for escaping string literals. Values are sourced from + INFORMATION_SCHEMA (database object names), so this is an extra safety + measure against edge-case object names. 
+    """
+    return value.replace("'", "''")
+
+
+def _token_bytes(access_token: str) -> bytes:
+    """Encode an AAD access token as a length-prefixed UTF-16-LE byte string
+    expected by the SQL Server ODBC driver (SQL_COPT_SS_ACCESS_TOKEN)."""
+    encoded = access_token.encode("UTF-16-LE")
+    return struct.pack(f"<I{len(encoded)}s", len(encoded), encoded)
+
+
+class FabricLakehouseDataLoader(ExternalDataLoader):
+
+    @staticmethod
+    def list_params() -> list[dict[str, Any]]:
+        return [
+            {
+                "name": "server",
+                "type": "string",
+                "required": True,
+                "default": "",
+                "description": (
+                    "Fabric SQL Analytics Endpoint hostname "
+                    "(e.g. <workspace>.datawarehouse.fabric.microsoft.com)"
+                ),
+            },
+            {
+                "name": "database",
+                "type": "string",
+                "required": True,
+                "default": "",
+                "description": "Lakehouse or Warehouse name in the Fabric workspace",
+            },
+            {
+                "name": "client_id",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "Entra ID application (client) ID for service principal auth",
+            },
+            {
+                "name": "client_secret",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "Entra ID client secret for service principal auth",
+            },
+            {
+                "name": "tenant_id",
+                "type": "string",
+                "required": False,
+                "default": "",
+                "description": "Entra ID tenant ID for service principal auth",
+            },
+            {
+                "name": "driver",
+                "type": "string",
+                "required": False,
+                "default": "ODBC Driver 18 for SQL Server",
+                "description": "ODBC driver name (ODBC Driver 17 or 18 for SQL Server)",
+            },
+        ]
+
+    @staticmethod
+    def auth_instructions() -> str:
+        return """**Example (Service Principal):** server: `<workspace>.datawarehouse.fabric.microsoft.com` · database: `MyLakehouse` · client_id: `abc-123...` · client_secret: `xyz...` · tenant_id: `def-456...`
+
+**Example (Azure CLI):** server: `<workspace>.datawarehouse.fabric.microsoft.com` · database: `MyLakehouse` (run `az login` first, leave credential fields empty)
+
+**How to find your SQL endpoint:**
+Open the Fabric portal → select your Lakehouse → click *SQL Analytics Endpoint* → copy the server hostname shown in the connection details.
+
+**Authentication Options:**
+- **Service Principal (Entra ID):** Register an Azure AD application, generate a client secret, and add it as a workspace member with at least the *Viewer* role in the Fabric portal (Workspace settings → Manage access).
+- **Azure CLI / DefaultAzureCredential:** Run `az login` in your terminal. Leave client_id, client_secret, and tenant_id empty.
+
+**Prerequisites:** ODBC Driver 17 or 18 for SQL Server must be installed. 
See [Microsoft Docs](https://learn.microsoft.com/sql/connect/odbc/download-odbc-driver-for-sql-server).""" + + def __init__(self, params: dict[str, Any]): + log.info(f"Initializing FabricLakehouse DataLoader with parameters: {params}") + self.params = params + + self.server = params.get("server", "").strip() + self.database = params.get("database", "").strip() + self.client_id = params.get("client_id", "").strip() + self.client_secret = params.get("client_secret", "").strip() + self.tenant_id = params.get("tenant_id", "").strip() + self.driver = ( + params.get("driver", "ODBC Driver 18 for SQL Server").strip() + or "ODBC Driver 18 for SQL Server" + ) + + if not self.server: + raise ValueError("Fabric SQL Analytics Endpoint hostname is required") + if not self.database: + raise ValueError("Lakehouse or Warehouse name (database) is required") + + # Verify the connection on initialisation + try: + conn = self._get_pyodbc_connection() + conn.close() + log.info( + f"Connected to Fabric SQL endpoint: {self.server}/{self.database}" + ) + except Exception as e: + raise ValueError( + f"Failed to connect to Fabric SQL endpoint '{self.server}': {e}" + ) from e + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_access_token(self) -> str: + """Obtain an AAD access token for the Fabric SQL Analytics Endpoint.""" + from azure.identity import ClientSecretCredential, DefaultAzureCredential + + if self.client_id and self.client_secret and self.tenant_id: + credential = ClientSecretCredential( + tenant_id=self.tenant_id, + client_id=self.client_id, + client_secret=self.client_secret, + ) + else: + credential = DefaultAzureCredential() + + return credential.get_token(_FABRIC_SQL_RESOURCE).token + + def _get_pyodbc_connection(self): + """Create a pyodbc connection to the Fabric SQL endpoint using AAD token auth.""" + import pyodbc + + conn_str = ( + f"DRIVER={{{self.driver}}};" + f"SERVER={self.server};" + f"DATABASE={self.database};" + "Encrypt=yes;" + "TrustServerCertificate=no;" + ) + token = _token_bytes(self._get_access_token()) + return pyodbc.connect(conn_str, attrs_before={_SQL_COPT_SS_ACCESS_TOKEN: token}) + + def _execute_query(self, query: str) -> pa.Table: + """Execute a T-SQL query against the Fabric endpoint and return an Arrow table.""" + conn = self._get_pyodbc_connection() + try: + df = pd.read_sql(query, conn) + return pa.Table.from_pandas(df, preserve_index=False) + finally: + conn.close() + + def _safe_select_list(self, schema: str, table_name: str) -> str: + """Build a SELECT column list that casts unsupported types to NVARCHAR. + + Returns ``'*'`` when no unsupported columns are present. 
+ """ + _spatial = {"geometry", "geography"} + _other = {"hierarchyid", "xml", "sql_variant", "image", "timestamp"} + _unsupported = _spatial | _other + try: + cols_df = self._execute_query( + f""" + SELECT COLUMN_NAME, DATA_TYPE + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = '{_esc(schema)}' AND TABLE_NAME = '{_esc(table_name)}' + ORDER BY ORDINAL_POSITION + """ + ).to_pandas() + if not any( + r["DATA_TYPE"].lower() in _unsupported + for _, r in cols_df.iterrows() + ): + return "*" + parts = [] + for _, r in cols_df.iterrows(): + col, dtype = r["COLUMN_NAME"], r["DATA_TYPE"].lower() + if dtype in _spatial: + parts.append(f"[{col}].STAsText() AS [{col}]") + elif dtype in _other: + parts.append(f"CAST([{col}] AS NVARCHAR(MAX)) AS [{col}]") + else: + parts.append(f"[{col}]") + return ", ".join(parts) + except Exception: + return "*" + + # ------------------------------------------------------------------ + # ExternalDataLoader interface + # ------------------------------------------------------------------ + + def fetch_data_as_arrow( + self, + source_table: str, + size: int = 1000000, + sort_columns: list[str] | None = None, + sort_order: str = "asc", + ) -> pa.Table: + """Fetch data from the Fabric Lakehouse SQL endpoint as a PyArrow Table.""" + if not source_table: + raise ValueError("source_table must be provided") + + if "." in source_table: + schema, table = source_table.split(".", 1) + else: + schema, table = "dbo", source_table + + schema = schema.strip("[]") + table = table.strip("[]") + + col_list = self._safe_select_list(schema, table) + base_query = f"SELECT {col_list} FROM [{schema}].[{table}]" + + order_by_clause = "" + if sort_columns: + direction = "DESC" if sort_order == "desc" else "ASC" + order_by_clause = ( + " ORDER BY " + + ", ".join(f"[{col}] {direction}" for col in sort_columns) + ) + + query = f"SELECT TOP {size} * FROM ({base_query}{order_by_clause}) AS _limited" + + log.info(f"Executing Fabric SQL query: {query[:200]}...") + arrow_table = self._execute_query(query) + log.info(f"Fetched {arrow_table.num_rows} rows from Fabric Lakehouse") + return arrow_table + + def list_tables(self, table_filter: str | None = None) -> list[dict[str, Any]]: + """List all tables from the Fabric Lakehouse SQL Analytics Endpoint.""" + try: + tables_df = self._execute_query( + """ + SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE IN ('BASE TABLE', 'VIEW') + AND TABLE_SCHEMA NOT IN ('sys', 'INFORMATION_SCHEMA') + ORDER BY TABLE_SCHEMA, TABLE_NAME + """ + ).to_pandas() + except Exception as e: + log.error(f"Failed to list tables from Fabric SQL endpoint: {e}") + return [] + + results = [] + for _, row in tables_df.iterrows(): + schema = row["TABLE_SCHEMA"] + table_name = row["TABLE_NAME"] + table_type = row.get("TABLE_TYPE", "BASE TABLE") + full_name = f"{schema}.{table_name}" + + if table_filter and table_filter.lower() not in full_name.lower(): + continue + + try: + columns_df = self._execute_query( + f""" + SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, + CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = '{_esc(schema)}' AND TABLE_NAME = '{_esc(table_name)}' + ORDER BY ORDINAL_POSITION + """ + ).to_pandas() + + columns = [] + for _, col_row in columns_df.iterrows(): + col_info: dict[str, Any] = { + "name": col_row["COLUMN_NAME"], + "type": col_row["DATA_TYPE"], + "nullable": col_row["IS_NULLABLE"] == "YES", + "default": col_row["COLUMN_DEFAULT"], + } 
+ for field, key in [ + ("CHARACTER_MAXIMUM_LENGTH", "max_length"), + ("NUMERIC_PRECISION", "precision"), + ("NUMERIC_SCALE", "scale"), + ]: + val = col_row[field] + if val is not None and not pd.isna(val): + try: + col_info[key] = int(val) + except (ValueError, TypeError): + pass + columns.append(col_info) + + col_list = self._safe_select_list(schema, table_name) + + sample_rows: list = [] + try: + sample_df = self._execute_query( + f"SELECT TOP 10 {col_list} FROM [{schema}].[{table_name}]" + ).to_pandas() + sample_rows = json.loads( + sample_df.fillna(value=None).to_json( + orient="records", date_format="iso", default_handler=str + ) + ) + except Exception as e: + log.warning(f"Failed to sample table {full_name}: {e}") + + count_df = self._execute_query( + f"SELECT COUNT(*) AS row_count FROM [{schema}].[{table_name}]" + ).to_pandas() + raw_count = count_df.iloc[0]["row_count"] + try: + row_count = 0 if pd.isna(raw_count) else int(raw_count) + except (ValueError, TypeError): + row_count = 0 + + results.append( + { + "name": full_name, + "metadata": { + "row_count": row_count, + "columns": columns, + "sample_rows": sample_rows, + "table_type": table_type, + }, + } + ) + except Exception as e: + log.warning(f"Failed to get metadata for table {full_name}: {e}") + results.append( + { + "name": full_name, + "metadata": { + "row_count": 0, + "columns": [], + "sample_rows": [], + "table_type": table_type, + }, + } + ) + + return results diff --git a/py-src/tests/__init__.py b/py-src/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/py-src/tests/test_azure_sql_and_fabric_loaders.py b/py-src/tests/test_azure_sql_and_fabric_loaders.py new file mode 100644 index 00000000..60e22967 --- /dev/null +++ b/py-src/tests/test_azure_sql_and_fabric_loaders.py @@ -0,0 +1,1200 @@ +""" +Tests for AzureSQLDataLoader and FabricLakehouseDataLoader. + +All network/driver calls are replaced with in-process mocks so the tests run +without any real Azure SQL or Fabric endpoints, ODBC drivers, or AAD credentials. 
+ +Simulated environment: + - Database schema: two tables (sales.orders, sales.customers) + - orders: id (int), amount (float), status (nvarchar) — 3 sample rows + - customers: id (int), name (nvarchar), city (nvarchar) — 2 sample rows + +The mocks intercept: + - connectorx.read_sql → used by AzureSQLDataLoader SQL-auth path + - pyodbc.connect → used by both loaders for Entra-ID auth + - azure.identity.*Credential.get_token → returns a fake AAD token +""" + +from __future__ import annotations + +import json +import struct +import types +from unittest.mock import MagicMock, patch, PropertyMock +from typing import Any + +import pandas as pd +import pyarrow as pa +import pytest + +# --------------------------------------------------------------------------- +# Shared test data +# --------------------------------------------------------------------------- + +_TABLES_DF = pd.DataFrame( + { + "TABLE_SCHEMA": ["sales", "sales"], + "TABLE_NAME": ["orders", "customers"], + "TABLE_TYPE": ["BASE TABLE", "BASE TABLE"], + } +) + +_ORDERS_COLS_DF = pd.DataFrame( + { + "COLUMN_NAME": ["id", "amount", "status"], + "DATA_TYPE": ["int", "float", "nvarchar"], + "IS_NULLABLE": ["NO", "YES", "YES"], + "COLUMN_DEFAULT": [None, None, None], + "CHARACTER_MAXIMUM_LENGTH": [None, None, 50], + "NUMERIC_PRECISION": [10, 15, None], + "NUMERIC_SCALE": [0, 2, None], + } +) + +_CUSTOMERS_COLS_DF = pd.DataFrame( + { + "COLUMN_NAME": ["id", "name", "city"], + "DATA_TYPE": ["int", "nvarchar", "nvarchar"], + "IS_NULLABLE": ["NO", "YES", "YES"], + "COLUMN_DEFAULT": [None, None, None], + "CHARACTER_MAXIMUM_LENGTH": [None, 100, 50], + "NUMERIC_PRECISION": [10, None, None], + "NUMERIC_SCALE": [0, None, None], + } +) + +_ORDERS_DATA_DF = pd.DataFrame( + { + "id": [1, 2, 3], + "amount": [99.99, 149.0, 299.50], + "status": ["shipped", "pending", "delivered"], + } +) + +_CUSTOMERS_DATA_DF = pd.DataFrame( + { + "id": [1, 2], + "name": ["Alice", "Bob"], + "city": ["Seattle", "New York"], + } +) + +_COUNT_ORDERS_DF = pd.DataFrame({"row_count": [3]}) +_COUNT_CUSTOMERS_DF = pd.DataFrame({"row_count": [2]}) + + +# --------------------------------------------------------------------------- +# Query router: dispatches query strings to mock DataFrames +# --------------------------------------------------------------------------- + +def _query_router(sql: str) -> pd.DataFrame: + """Map a SQL query string to a simulated result DataFrame.""" + sql_upper = sql.strip().upper() + + # Connection probe + if sql_upper == "SELECT 1": + return pd.DataFrame({"col": [1]}) + + # Table listing + if "INFORMATION_SCHEMA.TABLES" in sql_upper: + return _TABLES_DF + + # Column metadata + if "INFORMATION_SCHEMA.COLUMNS" in sql_upper: + if "'ORDERS'" in sql_upper or "ORDERS" in sql_upper: + return _ORDERS_COLS_DF + return _CUSTOMERS_COLS_DF + + # Row count + if "COUNT(*)" in sql_upper: + if "ORDERS" in sql_upper: + return _COUNT_ORDERS_DF + return _COUNT_CUSTOMERS_DF + + # Sample / full data fetch + if "ORDERS" in sql_upper: + return _ORDERS_DATA_DF + return _CUSTOMERS_DATA_DF + + +def _make_arrow(sql: str) -> pa.Table: + """Return a PyArrow Table for the given SQL query.""" + return pa.Table.from_pandas(_query_router(sql), preserve_index=False) + + +# --------------------------------------------------------------------------- +# Fixture factories +# --------------------------------------------------------------------------- + +def _make_cx_mock(): + """Mock connectorx.read_sql to route queries via _make_arrow.""" + cx_mock = MagicMock() + 
cx_mock.read_sql.side_effect = lambda url, query, **kw: _make_arrow(query) + return cx_mock + + +def _make_pyodbc_mock(): + """Mock pyodbc.connect; the returned connection routes pd.read_sql queries.""" + conn = MagicMock() + # pyodbc connections are used as context managers for pd.read_sql + conn.__enter__ = MagicMock(return_value=conn) + conn.__exit__ = MagicMock(return_value=False) + conn.close = MagicMock() + return conn + + +def _make_token_mock(): + """Return a mock azure.identity credential whose get_token() always succeeds.""" + token = MagicMock() + token.token = "fake-aad-token-for-testing" + cred = MagicMock() + cred.get_token.return_value = token + return cred + + +# --------------------------------------------------------------------------- +# Helper: build a loader instance without real I/O +# --------------------------------------------------------------------------- + +def _build_azure_sql_loader_sql_auth() -> "AzureSQLDataLoader": + """Instantiate AzureSQLDataLoader using SQL authentication with cx mocked.""" + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + cx_mock = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader = AzureSQLDataLoader( + { + "server": "myserver.database.windows.net", + "database": "mydb", + "user": "sa", + "password": "MyP@ss", + "port": "1433", + } + ) + # Attach the cx mock so we can call methods later + loader._cx_mock = cx_mock + return loader + + +def _build_azure_sql_loader_entra(cred_mock=None) -> "AzureSQLDataLoader": + """Instantiate AzureSQLDataLoader using Entra ID authentication.""" + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + if cred_mock is None: + cred_mock = _make_token_mock() + + pyodbc_mock = MagicMock() + conn = _make_pyodbc_mock() + pyodbc_mock.connect.return_value = conn + + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc_mock}), + patch( + "data_formulator.data_loader.azure_sql_data_loader.AzureSQLDataLoader._get_access_token", + return_value="fake-token", + ), + ): + loader = AzureSQLDataLoader( + { + "server": "myserver.database.windows.net", + "database": "mydb", + "client_id": "cid", + "client_secret": "csec", + "tenant_id": "tid", + } + ) + loader._pyodbc_mock = pyodbc_mock + return loader + + +def _build_fabric_loader(cred_mock=None) -> "FabricLakehouseDataLoader": + """Instantiate FabricLakehouseDataLoader with all I/O mocked.""" + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + if cred_mock is None: + cred_mock = _make_token_mock() + + pyodbc_mock = MagicMock() + conn = _make_pyodbc_mock() + pyodbc_mock.connect.return_value = conn + + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + ): + loader = FabricLakehouseDataLoader( + { + "server": "myworkspace.datawarehouse.fabric.microsoft.com", + "database": "MyLakehouse", + "client_id": "cid", + "client_secret": "csec", + "tenant_id": "tid", + } + ) + loader._pyodbc_mock = pyodbc_mock + return loader + + +# =========================================================================== +# Tests: module-level helpers +# =========================================================================== + + +class TestEscHelper: + def test_no_quotes(self): + from data_formulator.data_loader.azure_sql_data_loader import _esc + + assert _esc("myschema") == 
"myschema" + + def test_single_quote_escaped(self): + from data_formulator.data_loader.azure_sql_data_loader import _esc + + assert _esc("o'reilly") == "o''reilly" + + def test_multiple_quotes(self): + from data_formulator.data_loader.azure_sql_data_loader import _esc + + assert _esc("it's a test's value") == "it''s a test''s value" + + def test_fabric_esc_same_behaviour(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import _esc + + assert _esc("it's") == "it''s" + + +class TestTokenBytes: + def test_structure(self): + from data_formulator.data_loader.azure_sql_data_loader import _token_bytes + + result = _token_bytes("hello") + encoded = "hello".encode("UTF-16-LE") + length = struct.unpack_from(" 50 + + def test_registered_in_data_loaders(self): + from data_formulator.data_loader import DATA_LOADERS, AzureSQLDataLoader + + assert "azure_sql" in DATA_LOADERS + assert DATA_LOADERS["azure_sql"] is AzureSQLDataLoader + + +class TestFabricLakehouseStaticMethods: + def test_list_params_names(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + params = FabricLakehouseDataLoader.list_params() + names = [p["name"] for p in params] + assert "server" in names + assert "database" in names + assert "client_id" in names + assert "client_secret" in names + assert "tenant_id" in names + + def test_no_user_password_params(self): + """Fabric does not support SQL auth; user/password should not appear.""" + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + names = [p["name"] for p in FabricLakehouseDataLoader.list_params()] + assert "user" not in names + assert "password" not in names + + def test_server_and_database_are_required(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + by_name = {p["name"]: p for p in FabricLakehouseDataLoader.list_params()} + assert by_name["server"]["required"] is True + assert by_name["database"]["required"] is True + + def test_auth_instructions_nonempty(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + instr = FabricLakehouseDataLoader.auth_instructions() + assert isinstance(instr, str) and len(instr) > 50 + + def test_registered_in_data_loaders(self): + from data_formulator.data_loader import DATA_LOADERS, FabricLakehouseDataLoader + + assert "fabric_lakehouse" in DATA_LOADERS + assert DATA_LOADERS["fabric_lakehouse"] is FabricLakehouseDataLoader + + +# =========================================================================== +# Tests: AzureSQLDataLoader — init validation +# =========================================================================== + + +class TestAzureSQLInitValidation: + def test_missing_server_raises(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + with pytest.raises(ValueError, match="server"): + cx = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx}): + AzureSQLDataLoader({"server": "", "database": "mydb", "user": "u", "password": "p"}) + + def test_missing_database_raises(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + with pytest.raises(ValueError, match="[Dd]atabase"): + cx = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx}): + AzureSQLDataLoader({"server": "srv", "database": "", "user": "u", "password": "p"}) + + def 
test_connection_failure_wraps_error(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + cx = MagicMock() + cx.read_sql.side_effect = RuntimeError("connection refused") + with pytest.raises(ValueError, match="Failed to connect"): + with patch.dict("sys.modules", {"connectorx": cx}): + AzureSQLDataLoader( + {"server": "bad-srv", "database": "db", "user": "u", "password": "p"} + ) + + def test_entra_connection_failure_wraps_error(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + pyodbc = MagicMock() + pyodbc.connect.side_effect = RuntimeError("ODBC error") + with pytest.raises(ValueError, match="Failed to connect"): + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc}), + patch( + "data_formulator.data_loader.azure_sql_data_loader.AzureSQLDataLoader._get_access_token", + return_value="tok", + ), + ): + AzureSQLDataLoader( + { + "server": "srv", + "database": "db", + "client_id": "c", + "client_secret": "s", + "tenant_id": "t", + } + ) + + +# =========================================================================== +# Tests: FabricLakehouseDataLoader — init validation +# =========================================================================== + + +class TestFabricLakehouseInitValidation: + def test_missing_server_raises(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + with pytest.raises(ValueError, match="[Ss]erver|[Ee]ndpoint"): + pyodbc = MagicMock() + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="tok", + ), + ): + FabricLakehouseDataLoader({"server": "", "database": "MyLH"}) + + def test_missing_database_raises(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + with pytest.raises(ValueError, match="[Dd]atabase|[Ll]akehouse|[Ww]arehouse"): + pyodbc = MagicMock() + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="tok", + ), + ): + FabricLakehouseDataLoader({"server": "srv.fabric.microsoft.com", "database": ""}) + + def test_connection_failure_wraps_error(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + pyodbc = MagicMock() + pyodbc.connect.side_effect = RuntimeError("ODBC driver not found") + with pytest.raises(ValueError, match="Failed to connect"): + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="tok", + ), + ): + FabricLakehouseDataLoader( + {"server": "srv.fabric.microsoft.com", "database": "MyLH"} + ) + + +# =========================================================================== +# Tests: AzureSQLDataLoader — fetch_data_as_arrow (SQL auth path) +# =========================================================================== + + +class TestAzureSQLFetchData: + def _make_loader(self) -> "AzureSQLDataLoader": + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + cx_mock = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader = AzureSQLDataLoader( + { + "server": "myserver.database.windows.net", + "database": "mydb", + "user": 
"sa", + "password": "pass", + } + ) + loader._cx_mock = cx_mock + return loader + + def test_returns_arrow_table(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + result = loader.fetch_data_as_arrow("sales.orders") + assert isinstance(result, pa.Table) + assert result.num_rows == 3 + + def test_column_names(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + result = loader.fetch_data_as_arrow("sales.orders") + assert set(result.schema.names) == {"id", "amount", "status"} + + def test_table_without_schema_defaults_to_dbo(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + result = loader.fetch_data_as_arrow("orders") + # Should work — query should include dbo as schema + assert isinstance(result, pa.Table) + + def test_empty_source_table_raises(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with pytest.raises(ValueError, match="source_table"): + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader.fetch_data_as_arrow("") + + def test_size_limit_in_query(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + queries_seen: list[str] = [] + cx_mock.read_sql.side_effect = lambda url, q, **kw: ( + queries_seen.append(q) or _make_arrow(q) + ) + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader.fetch_data_as_arrow("sales.orders", size=50) + # The last query executed for data fetch should contain TOP 50 + data_queries = [q for q in queries_seen if "TOP 50" in q.upper() or "top 50" in q.lower()] + assert len(data_queries) >= 1 + + def test_sort_columns_asc(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + queries_seen: list[str] = [] + cx_mock.read_sql.side_effect = lambda url, q, **kw: ( + queries_seen.append(q) or _make_arrow(q) + ) + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader.fetch_data_as_arrow("sales.orders", sort_columns=["amount"]) + data_queries = [q for q in queries_seen if "ORDER BY" in q.upper()] + assert any("ASC" in q.upper() for q in data_queries) + + def test_sort_columns_desc(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + queries_seen: list[str] = [] + cx_mock.read_sql.side_effect = lambda url, q, **kw: ( + queries_seen.append(q) or _make_arrow(q) + ) + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader.fetch_data_as_arrow("sales.orders", sort_columns=["id"], sort_order="desc") + data_queries = [q for q in queries_seen if "ORDER BY" in q.upper()] + assert any("DESC" in q.upper() for q in data_queries) + + +# =========================================================================== +# Tests: AzureSQLDataLoader — list_tables (SQL auth path) +# =========================================================================== + + +class TestAzureSQLListTables: + def _make_loader(self) -> "AzureSQLDataLoader": + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + cx_mock = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader = AzureSQLDataLoader( + { + "server": "myserver.database.windows.net", + "database": "mydb", + "user": "sa", + "password": "pass", + } + ) + loader._cx_mock = cx_mock + return loader + + def test_returns_list(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = 
loader.list_tables() + assert isinstance(tables, list) + assert len(tables) == 2 + + def test_table_names(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = loader.list_tables() + names = [t["name"] for t in tables] + assert "sales.orders" in names + assert "sales.customers" in names + + def test_metadata_shape(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = loader.list_tables() + for t in tables: + md = t["metadata"] + assert "row_count" in md + assert "columns" in md + assert "sample_rows" in md + assert "table_type" in md + + def test_orders_row_count(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = loader.list_tables() + orders = next(t for t in tables if t["name"] == "sales.orders") + assert orders["metadata"]["row_count"] == 3 + + def test_orders_columns(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = loader.list_tables() + orders = next(t for t in tables if t["name"] == "sales.orders") + col_names = [c["name"] for c in orders["metadata"]["columns"]] + assert "id" in col_names + assert "amount" in col_names + assert "status" in col_names + + def test_table_filter(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = loader.list_tables(table_filter="orders") + assert len(tables) == 1 + assert tables[0]["name"] == "sales.orders" + + def test_table_filter_no_match(self): + loader = self._make_loader() + cx_mock = loader._cx_mock + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = loader.list_tables(table_filter="nonexistent_xyz") + assert tables == [] + + def test_list_tables_returns_empty_on_connection_error(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + cx_mock = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx_mock}): + loader = AzureSQLDataLoader( + {"server": "srv", "database": "db", "user": "u", "password": "p"} + ) + # Now break the connection + cx_mock.read_sql.side_effect = RuntimeError("connection lost") + with patch.dict("sys.modules", {"connectorx": cx_mock}): + tables = loader.list_tables() + assert tables == [] + + +# =========================================================================== +# Tests: AzureSQLDataLoader — Entra ID auth path +# =========================================================================== + + +class TestAzureSQLEntraAuth: + def _make_loader_and_conn(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + pyodbc_mock = MagicMock() + conn = _make_pyodbc_mock() + pyodbc_mock.connect.return_value = conn + + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc_mock}), + patch( + "data_formulator.data_loader.azure_sql_data_loader.AzureSQLDataLoader._get_access_token", + return_value="fake-token", + ), + ): + loader = AzureSQLDataLoader( + { + "server": "myserver.database.windows.net", + "database": "mydb", + "client_id": "cid", + "client_secret": "csec", + "tenant_id": "tid", + } + ) + return loader, pyodbc_mock, conn + + def test_auth_mode_is_entra(self): + loader, _, _ = self._make_loader_and_conn() + assert loader._auth_mode == "entra" + + def test_fetch_data_via_pyodbc(self): + loader, pyodbc_mock, conn 
= self._make_loader_and_conn() + # Make pd.read_sql return orders data when called with a pyodbc connection + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc_mock}), + patch( + "data_formulator.data_loader.azure_sql_data_loader.AzureSQLDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: _query_router(q)), + ): + result = loader.fetch_data_as_arrow("sales.orders") + assert isinstance(result, pa.Table) + + def test_get_safe_params_hides_secret(self): + loader, _, _ = self._make_loader_and_conn() + safe = loader.get_safe_params() + assert "client_secret" not in safe + assert "server" in safe + assert "database" in safe + + def test_service_principal_credential_used(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + pyodbc_mock = MagicMock() + conn = _make_pyodbc_mock() + pyodbc_mock.connect.return_value = conn + + client_cred_cls = MagicMock() + token_obj = MagicMock() + token_obj.token = "sp-token" + client_cred_cls.return_value.get_token.return_value = token_obj + + azure_identity_mock = MagicMock() + azure_identity_mock.ClientSecretCredential = client_cred_cls + azure_identity_mock.DefaultAzureCredential = MagicMock() + + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc_mock, "azure.identity": azure_identity_mock}), + patch("data_formulator.data_loader.azure_sql_data_loader.AzureSQLDataLoader.__init__", lambda s, p: None), + ): + loader = AzureSQLDataLoader.__new__(AzureSQLDataLoader) + loader.client_id = "cid" + loader.client_secret = "csec" + loader.tenant_id = "tid" + + # Call the real _get_access_token via the module's actual code path + import data_formulator.data_loader.azure_sql_data_loader as mod + original_fn = mod.AzureSQLDataLoader._get_access_token + + with patch.dict("sys.modules", {"azure.identity": azure_identity_mock}): + token = original_fn(loader) + + client_cred_cls.assert_called_once_with( + tenant_id="tid", client_id="cid", client_secret="csec" + ) + assert token == "sp-token" + + def test_default_credential_used_when_no_sp(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + default_cred_cls = MagicMock() + token_obj = MagicMock() + token_obj.token = "cli-token" + default_cred_cls.return_value.get_token.return_value = token_obj + + azure_identity_mock = MagicMock() + azure_identity_mock.ClientSecretCredential = MagicMock() + azure_identity_mock.DefaultAzureCredential = default_cred_cls + + with ( + patch("data_formulator.data_loader.azure_sql_data_loader.AzureSQLDataLoader.__init__", lambda s, p: None), + ): + loader = AzureSQLDataLoader.__new__(AzureSQLDataLoader) + loader.client_id = "" + loader.client_secret = "" + loader.tenant_id = "" + + import data_formulator.data_loader.azure_sql_data_loader as mod + + with patch.dict("sys.modules", {"azure.identity": azure_identity_mock}): + token = mod.AzureSQLDataLoader._get_access_token(loader) + + default_cred_cls.assert_called_once() + assert token == "cli-token" + + +# =========================================================================== +# Tests: FabricLakehouseDataLoader — fetch_data_as_arrow +# =========================================================================== + + +class TestFabricFetchData: + def _make_loader(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + pyodbc_mock = MagicMock() + conn = _make_pyodbc_mock() + pyodbc_mock.connect.return_value = conn + + with ( + 
patch.dict("sys.modules", {"pyodbc": pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + ): + loader = FabricLakehouseDataLoader( + { + "server": "ws.datawarehouse.fabric.microsoft.com", + "database": "MyLakehouse", + "client_id": "cid", + "client_secret": "csec", + "tenant_id": "tid", + } + ) + loader._pyodbc_mock = pyodbc_mock + return loader + + def test_returns_arrow_table(self): + loader = self._make_loader() + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: _query_router(q)), + ): + result = loader.fetch_data_as_arrow("sales.orders") + assert isinstance(result, pa.Table) + assert result.num_rows == 3 + + def test_empty_source_table_raises(self): + loader = self._make_loader() + with pytest.raises(ValueError, match="source_table"): + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: _query_router(q)), + ): + loader.fetch_data_as_arrow("") + + def test_size_limit_in_query(self): + loader = self._make_loader() + queries_seen: list[str] = [] + + def capturing_read_sql(q, c): + queries_seen.append(q) + return _query_router(q) + + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=capturing_read_sql), + ): + loader.fetch_data_as_arrow("sales.orders", size=25) + + data_queries = [q for q in queries_seen if "TOP 25" in q.upper() or "top 25" in q.lower()] + assert len(data_queries) >= 1 + + def test_sort_order_desc(self): + loader = self._make_loader() + queries_seen: list[str] = [] + + def capturing_read_sql(q, c): + queries_seen.append(q) + return _query_router(q) + + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=capturing_read_sql), + ): + loader.fetch_data_as_arrow("sales.orders", sort_columns=["id"], sort_order="desc") + + order_queries = [q for q in queries_seen if "ORDER BY" in q.upper()] + assert any("DESC" in q.upper() for q in order_queries) + + +# =========================================================================== +# Tests: FabricLakehouseDataLoader — list_tables +# =========================================================================== + + +class TestFabricListTables: + def _make_loader(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + pyodbc_mock = MagicMock() + conn = _make_pyodbc_mock() + pyodbc_mock.connect.return_value = conn + + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + ): + loader = FabricLakehouseDataLoader( + { + "server": "ws.datawarehouse.fabric.microsoft.com", + 
"database": "MyLakehouse", + } + ) + loader._pyodbc_mock = pyodbc_mock + return loader + + def test_returns_list(self): + loader = self._make_loader() + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: _query_router(q)), + ): + tables = loader.list_tables() + assert isinstance(tables, list) + assert len(tables) == 2 + + def test_table_names(self): + loader = self._make_loader() + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: _query_router(q)), + ): + tables = loader.list_tables() + names = [t["name"] for t in tables] + assert "sales.orders" in names + assert "sales.customers" in names + + def test_metadata_shape(self): + loader = self._make_loader() + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: _query_router(q)), + ): + tables = loader.list_tables() + for t in tables: + md = t["metadata"] + assert "row_count" in md + assert "columns" in md + assert "sample_rows" in md + assert "table_type" in md + + def test_table_filter(self): + loader = self._make_loader() + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: _query_router(q)), + ): + tables = loader.list_tables(table_filter="customers") + assert len(tables) == 1 + assert tables[0]["name"] == "sales.customers" + + def test_list_tables_empty_on_connection_error(self): + loader = self._make_loader() + broken_pyodbc = MagicMock() + broken_pyodbc.connect.side_effect = RuntimeError("ODBC error") + with ( + patch.dict("sys.modules", {"pyodbc": broken_pyodbc}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + ): + tables = loader.list_tables() + assert tables == [] + + def test_includes_views_in_table_listing(self): + """Fabric returns both BASE TABLE and VIEW rows — both should be listed.""" + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + tables_with_view = pd.DataFrame( + { + "TABLE_SCHEMA": ["dbo", "dbo"], + "TABLE_NAME": ["delta_table", "report_view"], + "TABLE_TYPE": ["BASE TABLE", "VIEW"], + } + ) + + def router_with_view(q: str) -> pd.DataFrame: + if "INFORMATION_SCHEMA.TABLES" in q.upper(): + return tables_with_view + if "INFORMATION_SCHEMA.COLUMNS" in q.upper(): + return pd.DataFrame( + { + "COLUMN_NAME": ["id"], + "DATA_TYPE": ["int"], + "IS_NULLABLE": ["NO"], + "COLUMN_DEFAULT": [None], + "CHARACTER_MAXIMUM_LENGTH": [None], + "NUMERIC_PRECISION": [10], + "NUMERIC_SCALE": [0], + } + ) + if "COUNT(*)" in q.upper(): + return pd.DataFrame({"row_count": [100]}) + return pd.DataFrame({"id": [1, 2]}) + + loader = self._make_loader() + with ( + patch.dict("sys.modules", {"pyodbc": loader._pyodbc_mock}), + 
patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="fake-token", + ), + patch("pandas.read_sql", side_effect=lambda q, c: router_with_view(q)), + ): + tables = loader.list_tables() + names = [t["name"] for t in tables] + assert "dbo.delta_table" in names + assert "dbo.report_view" in names + + +# =========================================================================== +# Tests: safe_params — sensitive fields removed +# =========================================================================== + + +class TestGetSafeParams: + def test_azure_sql_hides_password(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + cx = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx}): + loader = AzureSQLDataLoader( + { + "server": "srv", + "database": "db", + "user": "u", + "password": "super-secret", + } + ) + safe = loader.get_safe_params() + assert "password" not in safe + assert safe.get("server") == "srv" + + def test_azure_sql_hides_client_secret(self): + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + pyodbc = MagicMock() + conn = _make_pyodbc_mock() + pyodbc.connect.return_value = conn + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc}), + patch( + "data_formulator.data_loader.azure_sql_data_loader.AzureSQLDataLoader._get_access_token", + return_value="tok", + ), + ): + loader = AzureSQLDataLoader( + { + "server": "srv", + "database": "db", + "client_id": "cid", + "client_secret": "my-secret", + "tenant_id": "tid", + } + ) + safe = loader.get_safe_params() + assert "client_secret" not in safe + assert safe.get("client_id") == "cid" + + def test_fabric_hides_client_secret(self): + from data_formulator.data_loader.fabric_lakehouse_data_loader import ( + FabricLakehouseDataLoader, + ) + + pyodbc = MagicMock() + conn = _make_pyodbc_mock() + pyodbc.connect.return_value = conn + with ( + patch.dict("sys.modules", {"pyodbc": pyodbc}), + patch( + "data_formulator.data_loader.fabric_lakehouse_data_loader.FabricLakehouseDataLoader._get_access_token", + return_value="tok", + ), + ): + loader = FabricLakehouseDataLoader( + { + "server": "srv.fabric.microsoft.com", + "database": "MyLH", + "client_id": "cid", + "client_secret": "my-secret", + "tenant_id": "tid", + } + ) + safe = loader.get_safe_params() + assert "client_secret" not in safe + assert safe.get("server") == "srv.fabric.microsoft.com" + + +# =========================================================================== +# Tests: safe_select_list — unsupported type handling +# =========================================================================== + + +class TestSafeSelectList: + def _make_loader(self) -> "AzureSQLDataLoader": + from data_formulator.data_loader.azure_sql_data_loader import AzureSQLDataLoader + + cx = _make_cx_mock() + with patch.dict("sys.modules", {"connectorx": cx}): + loader = AzureSQLDataLoader( + {"server": "srv", "database": "db", "user": "u", "password": "p"} + ) + return loader + + def test_all_supported_types_returns_star(self): + loader = self._make_loader() + # Override _execute_query to return only supported column types + supported_cols = pa.Table.from_pandas( + pd.DataFrame({"COLUMN_NAME": ["id", "name"], "DATA_TYPE": ["int", "nvarchar"]}) + ) + loader._execute_query = MagicMock(return_value=supported_cols) + result = loader._safe_select_list("dbo", "mytable") + assert result == "*" + + def test_geometry_type_uses_stastext(self): + loader = 
self._make_loader() + cols_with_geo = pa.Table.from_pandas( + pd.DataFrame( + {"COLUMN_NAME": ["id", "location"], "DATA_TYPE": ["int", "geometry"]} + ) + ) + loader._execute_query = MagicMock(return_value=cols_with_geo) + result = loader._safe_select_list("dbo", "spatial_table") + assert "STAsText()" in result + assert "[id]" in result + + def test_xml_type_cast_to_nvarchar(self): + loader = self._make_loader() + cols_with_xml = pa.Table.from_pandas( + pd.DataFrame( + {"COLUMN_NAME": ["id", "doc"], "DATA_TYPE": ["int", "xml"]} + ) + ) + loader._execute_query = MagicMock(return_value=cols_with_xml) + result = loader._safe_select_list("dbo", "xml_table") + assert "NVARCHAR(MAX)" in result + assert "[doc]" in result + + def test_exception_returns_star(self): + loader = self._make_loader() + loader._execute_query = MagicMock(side_effect=RuntimeError("boom")) + result = loader._safe_select_list("dbo", "badtable") + assert result == "*"
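
# ---------------------------------------------------------------------------
# Hypothetical sketch, not the loader's actual implementation: the
# TestSafeSelectList cases above pin down an observable contract for the
# column-list rewriting, namely "*" when every column type is
# connectorx-friendly, ".STAsText()" for spatial columns, a CAST to
# NVARCHAR(MAX) for xml, and a fallback to "*" when metadata lookup fails.
# The helper below reproduces only the rewriting part of that contract in
# isolation; build_select_list and its type sets are invented names, not the
# PR's API.
# ---------------------------------------------------------------------------

SPATIAL_TYPES = {"geometry", "geography"}
CAST_TO_TEXT_TYPES = {"xml", "hierarchyid", "sql_variant"}


def build_select_list(columns: list[tuple[str, str]]) -> str:
    """Return a T-SQL select list for (column_name, data_type) pairs."""
    if not any(dtype in SPATIAL_TYPES | CAST_TO_TEXT_TYPES for _, dtype in columns):
        return "*"
    parts = []
    for name, dtype in columns:
        if dtype in SPATIAL_TYPES:
            # Spatial values are converted to WKT text so Arrow-based readers can ingest them.
            parts.append(f"[{name}].STAsText() AS [{name}]")
        elif dtype in CAST_TO_TEXT_TYPES:
            parts.append(f"CAST([{name}] AS NVARCHAR(MAX)) AS [{name}]")
        else:
            parts.append(f"[{name}]")
    return ", ".join(parts)


# Usage, mirroring test_geometry_type_uses_stastext:
#   build_select_list([("id", "int"), ("location", "geometry")])
#   -> "[id], [location].STAsText() AS [location]"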