@@ -25,7 +25,7 @@ def __init__(self, config: dict, dialect: str):
2525 self ._errors = set ()
2626 self ._fixed = None
2727 self ._config = config
28- self ._dynamic_tables : Set [ str ] = set ()
28+ self ._dynamic_tables : List [ expr . TableAlias ] = []
2929 self ._dialect = dialect
3030 self ._risk : List [float ] = []
3131
@@ -56,9 +56,13 @@ def config(self) -> dict:
5656 return self ._config
5757
5858 @property
59- def dynamic_tables (self ) -> Set [ str ]:
59+ def dynamic_tables (self ) -> List [ expr . TableAlias ]:
6060 return self ._dynamic_tables
6161
62+ @property
63+ def dynamic_tables_names (self ) -> Set [str ]:
64+ return {t .alias_or_name for t in self ._dynamic_tables }
65+
6266 @property
6367 def dialect (self ) -> str :
6468 return self ._dialect
@@ -255,13 +259,16 @@ def _verify_query_statement(query_statement: expr.Query, context: _VerificationC
255259 _verify_query_statement (query_statement .right , context )
256260 return
257261 for cte in query_statement .ctes :
258- context . dynamic_tables . add (cte . alias )
262+ _add_table_alias (cte , context )
259263 _verify_query_statement (cte .this , context )
260264 from_tables = _get_from_clause_tables (query_statement , context )
261265 for t in from_tables :
262266 found = False
263267 for config_t in context .config ["tables" ]:
264- if t .name == config_t ["table_name" ] or t .name in context .dynamic_tables :
268+ if (
269+ t .name == config_t ["table_name" ]
270+ or t .name in context .dynamic_tables_names
271+ ):
265272 found = True
266273 if not found :
267274 context .add_error (f"Table { t .name } is not allowed" , False , 1 )
@@ -332,42 +339,30 @@ def _verify_col(
332339 """
333340 if (
334341 col .table == "sub_select"
335- or col .table != ""
336- and col .table in context .dynamic_tables
342+ or (col .table != "" and col .table in context .dynamic_tables_names )
343+ or (all (t .name in context .dynamic_tables_names for t in from_tables ))
344+ or (
345+ col .table == ""
346+ and col .name
347+ in [c .alias_or_name for t in context .dynamic_tables for c in t .columns ]
348+ )
349+ or (
350+ any (
351+ col .name in config_t ["columns" ]
352+ for config_t in context .config ["tables" ]
353+ for t in from_tables
354+ if t .name == config_t ["table_name" ]
355+ )
356+ )
337357 ):
338- pass
339- elif not _find_column ( col . name , from_tables , context ) :
358+ return True
359+ else :
340360 context .add_error (
341361 f"Column { col .name } is not allowed. Column removed from SELECT clause" ,
342362 True ,
343363 0.3 ,
344364 )
345365 return False
346- return True
347-
348-
349- def _find_column (
350- col_name : str , from_tables : List [expr .Table ], result : _VerificationContext
351- ) -> bool :
352- """
353- Finds a column in the given tables based on the provided column name.
354-
355- Args:
356- col_name (str): The name of the column to find.
357- from_tables (List[expr.Table]): The list of tables to search within.
358- result (_VerificationContext): The context for verification.
359-
360- Returns:
361- bool: True if the column is found in any of the tables, False otherwise.
362- """
363- if all (t .name in result .dynamic_tables for t in from_tables ):
364- return True
365- for t in from_tables :
366- for config_t in result .config ["tables" ]:
367- if t .name == config_t ["table_name" ]:
368- if col_name in config_t ["columns" ]:
369- return True
370- return False
371366
372367
373368def _get_from_clause_tables (
@@ -391,18 +386,24 @@ def _get_from_clause_tables(
391386 for t in _find_direct (clause , expr .Table ):
392387 if isinstance (t , expr .Table ):
393388 result .append (t )
394- for j in _find_direct (clause , expr .Subquery ):
395- if j .alias != "" :
396- context .dynamic_tables .add (j .alias )
397- _verify_query_statement (j .this , context )
389+ for l in _find_direct (clause , expr .Subquery ):
390+ _add_table_alias (l , context )
391+ _verify_query_statement (l .this , context )
398392 if join_clause :
399- for j in _find_direct (clause , expr .Lateral ):
400- if j .alias != "" :
401- context .dynamic_tables .add (j .alias )
402- _verify_query_statement (j .this .find (expr .Select ), context )
393+ for l in _find_direct (join_clause , expr .Lateral ):
394+ _add_table_alias (l , context )
395+ _verify_query_statement (l .this .find (expr .Select ), context )
396+ for u in _find_direct (join_clause , expr .Unnest ):
397+ _add_table_alias (u , context )
403398 return result
404399
405400
401+ def _add_table_alias (exp : expr .Expression , context : _VerificationContext ):
402+ for table_alias in _find_direct (exp , expr .TableAlias ):
403+ if isinstance (table_alias , expr .TableAlias ):
404+ context .dynamic_tables .append (table_alias )
405+
406+
406407def _split_to_expressions (
407408 exp : expr .Expression , exp_type : Type [expr .Expression ]
408409) -> Generator [expr .Expression , None , None ]:
0 commit comments