@@ -18,6 +18,7 @@ class _VerificationContext:
1818 _dynamic_tables (Set[str]): Set of dynamic tables found in the query, like sub select and WITH clauses.
1919 _dialect (str): The SQL dialect to use for parsing.
2020 """
21+
2122 def __init__ (self , config : dict , dialect : str ):
2223 super ().__init__ ()
2324 self ._can_fix = True
@@ -50,7 +51,6 @@ def fixed(self) -> Optional[str]:
5051 def fixed (self , value : Optional [str ]):
5152 self ._fixed = value
5253
53-
5454 @property
5555 def config (self ) -> dict :
5656 return self ._config
@@ -95,31 +95,48 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
9595 if isinstance (parsed , expr .Command ):
9696 result .add_error (f"{ parsed .name } statement is not allowed" , False , 0.9 )
9797 elif isinstance (parsed , expr .Delete ) or isinstance (parsed , expr .Insert ):
98- result .add_error (f"{ parsed .key .upper ()} statement is not allowed" , False , 0.9 )
98+ result .add_error (
99+ f"{ parsed .key .upper ()} statement is not allowed" , False , 0.9
100+ )
99101 elif isinstance (parsed , expr .Query ):
100102 _verify_query_statement (parsed , result )
101103 else :
102104 result .add_error ("Could not find a query statement" , False , 0.7 )
103105 if result .can_fix and len (result .errors ) > 0 :
104106 result .fixed = parsed .sql ()
105- return { "allowed" : len (result .errors ) == 0 , "errors" : result .errors , "fixed" : result .fixed , "risk" : result .risk }
106-
107-
108- def _verify_where_clause (context : _VerificationContext , select_statement : expr .Query ,
109- from_tables : List [expr .Table ]):
107+ return {
108+ "allowed" : len (result .errors ) == 0 ,
109+ "errors" : result .errors ,
110+ "fixed" : result .fixed ,
111+ "risk" : result .risk ,
112+ }
113+
114+
115+ def _verify_where_clause (
116+ context : _VerificationContext ,
117+ select_statement : expr .Query ,
118+ from_tables : List [expr .Table ],
119+ ):
110120 _verify_static_expression (select_statement , context )
111121 _verify_restrictions (select_statement , context , from_tables )
112122
113- def _verify_restrictions (select_statement : expr .Query ,
114- context : _VerificationContext ,
115- from_tables : List [expr .Table ]):
123+
124+ def _verify_restrictions (
125+ select_statement : expr .Query ,
126+ context : _VerificationContext ,
127+ from_tables : List [expr .Table ],
128+ ):
116129 where_clause = select_statement .find (expr .Where )
117130 if where_clause is None :
118131 where_clause = select_statement .find (expr .Where )
119132 and_exps = []
120133 else :
121134 and_exps = list (_split_to_expressions (where_clause .this , expr .And ))
122- for t in [c_t for c_t in context .config ["tables" ] if c_t ["table_name" ] in [t .name for t in from_tables ]]:
135+ for t in [
136+ c_t
137+ for c_t in context .config ["tables" ]
138+ if c_t ["table_name" ] in [t .name for t in from_tables ]
139+ ]:
123140 for idx , r in enumerate (t .get ("restrictions" , [])):
124141 found = False
125142 for sub_exp in and_exps :
@@ -129,18 +146,30 @@ def _verify_restrictions(select_statement: expr.Query,
129146 if not found :
130147 context .add_error (
131148 f"Missing restriction for table: { t ['table_name' ]} column: { r ['column' ]} value: { r ['value' ]} " ,
132- True , 0.5 )
149+ True ,
150+ 0.5 ,
151+ )
133152 value = f"'{ r ['value' ]} '" if isinstance (r ["value" ], str ) else r ["value" ]
134- new_condition = sqlglot .parse_one (f"{ r ['column' ]} = { value } " , dialect = context .dialect )
153+ new_condition = sqlglot .parse_one (
154+ f"{ r ['column' ]} = { value } " , dialect = context .dialect
155+ )
135156 if where_clause is None :
136157 where_clause = expr .Where (this = new_condition )
137158 select_statement .set ("where" , where_clause )
138159 else :
139- where_clause = where_clause .replace (expr .Where (this = expr .And (this = expr .paren (where_clause .this ),
140- expression = new_condition )))
141-
142-
143- def _verify_static_expression (select_statement : expr .Query , context : _VerificationContext ) -> bool :
160+ where_clause = where_clause .replace (
161+ expr .Where (
162+ this = expr .And (
163+ this = expr .paren (where_clause .this ),
164+ expression = new_condition ,
165+ )
166+ )
167+ )
168+
169+
170+ def _verify_static_expression (
171+ select_statement : expr .Query , context : _VerificationContext
172+ ) -> bool :
144173 has_static_exp = False
145174 where_clause = select_statement .find (expr .Where )
146175 if where_clause :
@@ -152,6 +181,7 @@ def _verify_static_expression(select_statement: expr.Query, context: _Verificati
152181 simplify (where_clause )
153182 return not has_static_exp
154183
184+
155185def _has_static_expression (context : _VerificationContext , exp : expr .Expression ) -> bool :
156186 if isinstance (exp , expr .Not ):
157187 return _has_static_expression (context , exp .this )
@@ -162,7 +192,8 @@ def _has_static_expression(context: _VerificationContext, exp: expr.Expression)
162192 result = _has_static_expression (context , sub_exp )
163193 elif not sub_exp .find (expr .Column ):
164194 context .add_error (
165- f"Static expression is not allowed: { sub_exp .sql ()} " , True , 0.8 )
195+ f"Static expression is not allowed: { sub_exp .sql ()} " , True , 0.8
196+ )
166197 par = sub_exp .parent
167198 while isinstance (par , expr .Paren ):
168199 par = par .parent
@@ -176,15 +207,15 @@ def _has_static_expression(context: _VerificationContext, exp: expr.Expression)
176207
177208def _verify_restriction (restriction : dict , exp : expr .Expression ) -> bool :
178209 """
179- Verifies if a given restriction is satisfied within an SQL expression.
210+ Verifies if a given restriction is satisfied within an SQL expression.
180211
181- Args:
182- restriction (dict): The restriction to verify, containing 'column' and 'value' keys.
183- exp (list): The SQL expression to check against the restriction.
212+ Args:
213+ restriction (dict): The restriction to verify, containing 'column' and 'value' keys.
214+ exp (list): The SQL expression to check against the restriction.
184215
185- Returns:
186- bool: True if the restriction is satisfied, False otherwise.
187- """
216+ Returns:
217+ bool: True if the restriction is satisfied, False otherwise.
218+ """
188219 if isinstance (exp , expr .Not ):
189220 return False
190221 if isinstance (exp , expr .Paren ):
@@ -202,8 +233,8 @@ def _verify_restriction(restriction: dict, exp: expr.Expression) -> bool:
202233 return exp .right .this == str (restriction ["value" ])
203234 return False
204235
205- def _verify_query_statement ( query_statement : expr . Query ,
206- context : _VerificationContext ):
236+
237+ def _verify_query_statement ( query_statement : expr . Query , context : _VerificationContext ):
207238 if isinstance (query_statement , expr .Union ):
208239 _verify_query_statement (query_statement .left , context )
209240 _verify_query_statement (query_statement .right , context )
@@ -226,9 +257,11 @@ def _verify_query_statement(query_statement: expr.Query,
226257 return query_statement
227258
228259
229- def _verify_select_clause (context : _VerificationContext ,
230- select_clause : expr .Query ,
231- from_tables : List [expr .Table ]):
260+ def _verify_select_clause (
261+ context : _VerificationContext ,
262+ select_clause : expr .Query ,
263+ from_tables : List [expr .Table ],
264+ ):
232265 to_remove = []
233266 for e in select_clause .expressions :
234267 if not _verify_select_clause_element (from_tables , context , e ):
@@ -238,8 +271,10 @@ def _verify_select_clause(context: _VerificationContext,
238271 if len (select_clause .expressions ) == 0 :
239272 context .add_error ("No legal elements in SELECT clause" , False , 0.5 )
240273
241- def _verify_select_clause_element (from_tables : List [expr .Table ], context : _VerificationContext ,
242- e : expr .Expression ):
274+
275+ def _verify_select_clause_element (
276+ from_tables : List [expr .Table ], context : _VerificationContext , e : expr .Expression
277+ ):
243278 if isinstance (e , expr .Column ):
244279 if not _verify_col (e , from_tables , context ):
245280 return False
@@ -249,7 +284,9 @@ def _verify_select_clause_element(from_tables: List[expr.Table], context: _Verif
249284 for config_t in context .config ["tables" ]:
250285 if t .name == config_t ["table_name" ]:
251286 for c in config_t ["columns" ]:
252- e .parent .set ("expressions" , e .parent .expressions + [sqlglot .parse_one (c )])
287+ e .parent .set (
288+ "expressions" , e .parent .expressions + [sqlglot .parse_one (c )]
289+ )
253290 return False
254291 elif isinstance (e , expr .Tuple ):
255292 result = True
@@ -263,7 +300,10 @@ def _verify_select_clause_element(from_tables: List[expr.Table], context: _Verif
263300 return False
264301 return True
265302
266- def _verify_col (col : expr .Column , from_tables : List [expr .Table ], context : _VerificationContext ) -> bool :
303+
304+ def _verify_col (
305+ col : expr .Column , from_tables : List [expr .Table ], context : _VerificationContext
306+ ) -> bool :
267307 """
268308 Verifies if a column reference is allowed based on the provided tables and context.
269309
@@ -275,16 +315,25 @@ def _verify_col(col: expr.Column, from_tables: List[expr.Table], context: _Verif
275315 Returns:
276316 bool: True if the column reference is allowed, False otherwise.
277317 """
278- if col .table == "sub_select" or col .table != "" and col .table in context .dynamic_tables :
318+ if (
319+ col .table == "sub_select"
320+ or col .table != ""
321+ and col .table in context .dynamic_tables
322+ ):
279323 pass
280324 elif not _find_column (col .name , from_tables , context ):
281- context .add_error (f"Column { col .name } is not allowed. Column removed from SELECT clause" ,
282- True ,0.3 )
325+ context .add_error (
326+ f"Column { col .name } is not allowed. Column removed from SELECT clause" ,
327+ True ,
328+ 0.3 ,
329+ )
283330 return False
284331 return True
285332
286333
287- def _find_column (col_name : str , from_tables : List [expr .Table ], result : _VerificationContext ) -> bool :
334+ def _find_column (
335+ col_name : str , from_tables : List [expr .Table ], result : _VerificationContext
336+ ) -> bool :
288337 """
289338 Finds a column in the given tables based on the provided column name.
290339
@@ -306,16 +355,18 @@ def _find_column(col_name: str, from_tables: List[expr.Table], result: _Verifica
306355 return False
307356
308357
309- def _get_from_clause_tables (select_clause : expr .Query , context : _VerificationContext ) -> List [expr .Table ]:
358+ def _get_from_clause_tables (
359+ select_clause : expr .Query , context : _VerificationContext
360+ ) -> List [expr .Table ]:
310361 """
311- Extracts table references from the FROM clause of an SQL query.
362+ Extracts table references from the FROM clause of an SQL query.
312363
313- Args:
314- select_clause (dict): The FROM clause of the SQL query.
315- context (_VerificationContext): The context for verification.
364+ Args:
365+ select_clause (dict): The FROM clause of the SQL query.
366+ context (_VerificationContext): The context for verification.
316367
317- Returns:
318- List[_TableRef]: A list of table references to find in the FROM clause.
368+ Returns:
369+ List[_TableRef]: A list of table references to find in the FROM clause.
319370 """
320371 result = []
321372 from_clause = select_clause .find (expr .From )
@@ -337,13 +388,15 @@ def _get_from_clause_tables(select_clause: expr.Query, context: _VerificationCon
337388 return result
338389
339390
340- def _split_to_expressions (exp : expr .Expression ,
341- exp_type : Type [expr .Expression ]) -> Generator [expr .Expression , None , None ]:
391+ def _split_to_expressions (
392+ exp : expr .Expression , exp_type : Type [expr .Expression ]
393+ ) -> Generator [expr .Expression , None , None ]:
342394 if isinstance (exp , exp_type ):
343395 yield from exp .flatten ()
344396 else :
345397 yield exp
346398
399+
347400def _find_direct (exp : expr .Expression , exp_type : Type [expr .Expression ]):
348401 for child in exp .args .values ():
349402 if isinstance (child , exp_type ):
0 commit comments