@@ -25,15 +25,17 @@ def __init__(self, config: dict, dialect: str):
2525 self ._config = config
2626 self ._dynamic_tables : Set [str ] = set ()
2727 self ._dialect = dialect
28+ self ._risk : List [float ] = []
2829
2930 @property
3031 def can_fix (self ) -> bool :
3132 return self ._can_fix
3233
33- def add_error (self , error : str , can_fix : bool = True ):
34+ def add_error (self , error : str , can_fix : bool , risk : float ):
3435 self ._errors .add (error )
3536 if not can_fix :
3637 self ._can_fix = False
38+ self ._risk .append (risk )
3739
3840 @property
3941 def errors (self ) -> Set [str ]:
@@ -60,6 +62,10 @@ def dynamic_tables(self) -> Set[str]:
6062 def dialect (self ) -> str :
6163 return self ._dialect
6264
65+ @property
66+ def risk (self ) -> float :
67+ return sum (self ._risk ) / len (self ._risk ) if len (self ._risk ) > 0 else 0
68+
6369
6470def verify_sql (sql : str , config : dict , dialect : str = None ) -> dict :
6571 """
@@ -75,19 +81,23 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
7581 - "allowed" (bool): Whether the query is allowed to run.
7682 - "errors" (List[str]): List of errors found during verification.
7783 - "fixed" (Optional[str]): The fixed query if modifications were made.
84+ - "risk" (float): Verification risk score (0 - no risk, 1 - high risk)
7885 """
86+ result = _VerificationContext (config , dialect )
7987 try :
8088 parsed = sqlglot .parse_one (sql , dialect = dialect )
8189 except sqlglot .errors .ParseError as e :
8290 logging .error (f"SQL: { sql } \n Error parsing SQL: { e } " )
83- return { "allowed" : False , "errors" : e .errors }
84- if not isinstance (parsed , expr .Select ):
85- return {"allowed" : False , "errors" : ["Could not find a select statement" ]}
86- result = _VerificationContext (config , dialect )
87- _verify_select_statement (parsed , result )
91+ result .add_error (f"Error parsing sql: { e } " , False , 0.9 )
92+ parsed = None
93+ if parsed :
94+ if not isinstance (parsed , expr .Select ):
95+ result .add_error ("Could not find a select statement" , False , 0.7 )
96+ else :
97+ _verify_select_statement (parsed , result )
8898 if result .can_fix and len (result .errors ) > 0 :
8999 result .fixed = parsed .sql ()
90- return { "allowed" : len (result .errors ) == 0 , "errors" : result .errors , "fixed" : result .fixed }
100+ return { "allowed" : len (result .errors ) == 0 , "errors" : result .errors , "fixed" : result .fixed , "risk" : result . risk }
91101
92102
93103def _verify_where_clause (result : _VerificationContext , select_statement : expr .Select ,
@@ -108,7 +118,9 @@ def _verify_where_clause(result: _VerificationContext, select_statement: expr.Se
108118 found = True
109119 break
110120 if not found :
111- result .add_error (f"Missing restriction for table: { t ['table_name' ]} column: { r ['column' ]} value: { r ['value' ]} " )
121+ result .add_error (
122+ f"Missing restriction for table: { t ['table_name' ]} column: { r ['column' ]} value: { r ['value' ]} " ,
123+ True , 0.5 )
112124 value = f"'{ r ['value' ]} '" if isinstance (r ["value" ], str ) else r ["value" ]
113125 new_condition = sqlglot .parse_one (f"{ r ['column' ]} = { value } " , dialect = result .dialect )
114126 if where_clause is None :
@@ -135,7 +147,7 @@ def _has_static_expression(context: _VerificationContext, exp: expr.Expression)
135147 result = _has_static_expression (context , sub_exp )
136148 elif not sub_exp .find (expr .Column ):
137149 context .add_error (
138- f"Static expression is not allowed: { sub_exp .sql ()} " , False )
150+ f"Static expression is not allowed: { sub_exp .sql ()} " , False , 0.9 )
139151 result = True
140152 return result
141153
@@ -181,7 +193,7 @@ def _verify_select_statement(select_statement: expr.Select,
181193 if t .name == config_t ["table_name" ] or t .name in context .dynamic_tables :
182194 found = True
183195 if not found :
184- context .add_error (f"Table { t .name } is not allowed" , False )
196+ context .add_error (f"Table { t .name } is not allowed" , False , 1 )
185197 if not context .can_fix :
186198 return select_statement
187199 _verify_select_clause (context , select_statement , from_tables )
@@ -199,15 +211,15 @@ def _verify_select_clause(context: _VerificationContext,
199211 for e in to_remove :
200212 select_clause .expressions .remove (e )
201213 if len (select_clause .expressions ) == 0 :
202- context .add_error ("No legal elements in SELECT clause" , False )
214+ context .add_error ("No legal elements in SELECT clause" , False , 0.5 )
203215
204216def _verify_select_clause_element (from_tables : List [expr .Table ], context : _VerificationContext ,
205217 e : expr .Expression ):
206218 if isinstance (e , expr .Column ):
207219 if not _verify_col (e , from_tables , context ):
208220 return False
209221 elif isinstance (e , expr .Star ):
210- context .add_error ("SELECT * is not allowed" , True )
222+ context .add_error ("SELECT * is not allowed" , True , 0.1 )
211223 for t in from_tables :
212224 for config_t in context .config ["tables" ]:
213225 if t .name == config_t ["table_name" ]:
@@ -241,7 +253,8 @@ def _verify_col(col: expr.Column, from_tables: List[expr.Table], context: _Verif
241253 if col .table == "sub_select" or col .table != "" and col .table in context .dynamic_tables :
242254 pass
243255 elif not _find_column (col .name , from_tables , context ):
244- context .add_error (f"Column { col .name } is not allowed. Column removed from SELECT clause" )
256+ context .add_error (f"Column { col .name } is not allowed. Column removed from SELECT clause" ,
257+ True ,0.3 )
245258 return False
246259 return True
247260
0 commit comments