Skip to content

Commit 4a49c50

Browse files
committed
support unnest
1 parent be10b18 commit 4a49c50

File tree

2 files changed

+56
-41
lines changed

2 files changed

+56
-41
lines changed

src/sql_data_guard/sql_data_guard.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, config: dict, dialect: str):
2525
self._errors = set()
2626
self._fixed = None
2727
self._config = config
28-
self._dynamic_tables: Set[str] = set()
28+
self._dynamic_tables: List[expr.TableAlias] = []
2929
self._dialect = dialect
3030
self._risk: List[float] = []
3131

@@ -56,9 +56,13 @@ def config(self) -> dict:
5656
return self._config
5757

5858
@property
59-
def dynamic_tables(self) -> Set[str]:
59+
def dynamic_tables(self) -> List[expr.TableAlias]:
6060
return self._dynamic_tables
6161

62+
@property
63+
def dynamic_tables_names(self) -> Set[str]:
64+
return {t.alias_or_name for t in self._dynamic_tables}
65+
6266
@property
6367
def dialect(self) -> str:
6468
return self._dialect
@@ -255,13 +259,16 @@ def _verify_query_statement(query_statement: expr.Query, context: _VerificationC
255259
_verify_query_statement(query_statement.right, context)
256260
return
257261
for cte in query_statement.ctes:
258-
context.dynamic_tables.add(cte.alias)
262+
_add_table_alias(cte, context)
259263
_verify_query_statement(cte.this, context)
260264
from_tables = _get_from_clause_tables(query_statement, context)
261265
for t in from_tables:
262266
found = False
263267
for config_t in context.config["tables"]:
264-
if t.name == config_t["table_name"] or t.name in context.dynamic_tables:
268+
if (
269+
t.name == config_t["table_name"]
270+
or t.name in context.dynamic_tables_names
271+
):
265272
found = True
266273
if not found:
267274
context.add_error(f"Table {t.name} is not allowed", False, 1)
@@ -332,42 +339,30 @@ def _verify_col(
332339
"""
333340
if (
334341
col.table == "sub_select"
335-
or col.table != ""
336-
and col.table in context.dynamic_tables
342+
or (col.table != "" and col.table in context.dynamic_tables_names)
343+
or (all(t.name in context.dynamic_tables_names for t in from_tables))
344+
or (
345+
col.table == ""
346+
and col.name
347+
in [c.alias_or_name for t in context.dynamic_tables for c in t.columns]
348+
)
349+
or (
350+
any(
351+
col.name in config_t["columns"]
352+
for config_t in context.config["tables"]
353+
for t in from_tables
354+
if t.name == config_t["table_name"]
355+
)
356+
)
337357
):
338-
pass
339-
elif not _find_column(col.name, from_tables, context):
358+
return True
359+
else:
340360
context.add_error(
341361
f"Column {col.name} is not allowed. Column removed from SELECT clause",
342362
True,
343363
0.3,
344364
)
345365
return False
346-
return True
347-
348-
349-
def _find_column(
350-
col_name: str, from_tables: List[expr.Table], result: _VerificationContext
351-
) -> bool:
352-
"""
353-
Finds a column in the given tables based on the provided column name.
354-
355-
Args:
356-
col_name (str): The name of the column to find.
357-
from_tables (List[expr.Table]): The list of tables to search within.
358-
result (_VerificationContext): The context for verification.
359-
360-
Returns:
361-
bool: True if the column is found in any of the tables, False otherwise.
362-
"""
363-
if all(t.name in result.dynamic_tables for t in from_tables):
364-
return True
365-
for t in from_tables:
366-
for config_t in result.config["tables"]:
367-
if t.name == config_t["table_name"]:
368-
if col_name in config_t["columns"]:
369-
return True
370-
return False
371366

372367

373368
def _get_from_clause_tables(
@@ -391,18 +386,24 @@ def _get_from_clause_tables(
391386
for t in _find_direct(clause, expr.Table):
392387
if isinstance(t, expr.Table):
393388
result.append(t)
394-
for j in _find_direct(clause, expr.Subquery):
395-
if j.alias != "":
396-
context.dynamic_tables.add(j.alias)
397-
_verify_query_statement(j.this, context)
389+
for l in _find_direct(clause, expr.Subquery):
390+
_add_table_alias(l, context)
391+
_verify_query_statement(l.this, context)
398392
if join_clause:
399-
for j in _find_direct(clause, expr.Lateral):
400-
if j.alias != "":
401-
context.dynamic_tables.add(j.alias)
402-
_verify_query_statement(j.this.find(expr.Select), context)
393+
for l in _find_direct(join_clause, expr.Lateral):
394+
_add_table_alias(l, context)
395+
_verify_query_statement(l.this.find(expr.Select), context)
396+
for u in _find_direct(join_clause, expr.Unnest):
397+
_add_table_alias(u, context)
403398
return result
404399

405400

401+
def _add_table_alias(exp: expr.Expression, context: _VerificationContext):
402+
for table_alias in _find_direct(exp, expr.TableAlias):
403+
if isinstance(table_alias, expr.TableAlias):
404+
context.dynamic_tables.append(table_alias)
405+
406+
406407
def _split_to_expressions(
407408
exp: expr.Expression, exp_type: Type[expr.Expression]
408409
) -> Generator[expr.Expression, None, None]:

test/test_sql_guard_unit.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,20 @@ def test_round_transform(self, config):
344344
dialect="trino",
345345
)
346346

347+
def test_cross_join_unnest_access_column_with_alias(self, config):
348+
verify_sql_test(
349+
"SELECT t.val FROM highlights CROSS JOIN UNNEST(vals) AS t(val)",
350+
config,
351+
dialect="trino",
352+
)
353+
354+
def test_cross_join_unnest_access_column_without_alias(self, config):
355+
verify_sql_test(
356+
"SELECT val FROM highlights CROSS JOIN UNNEST(vals) AS t(val)",
357+
config,
358+
dialect="trino",
359+
)
360+
347361

348362
class TestRestrictionsWithDifferentDataTypes:
349363
@pytest.fixture(scope="class")

0 commit comments

Comments
 (0)