Skip to content

Commit 582dcf8

Browse files
committed
add risk
1 parent f8d4f72 commit 582dcf8

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

src/sql_data_guard/sql_data_guard.py

Lines changed: 26 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -25,15 +25,17 @@ def __init__(self, config: dict, dialect: str):
2525
self._config = config
2626
self._dynamic_tables: Set[str] = set()
2727
self._dialect = dialect
28+
self._risk: List[float] = []
2829

2930
@property
3031
def can_fix(self) -> bool:
3132
return self._can_fix
3233

33-
def add_error(self, error: str, can_fix: bool = True):
34+
def add_error(self, error: str, can_fix: bool, risk: float):
3435
self._errors.add(error)
3536
if not can_fix:
3637
self._can_fix = False
38+
self._risk.append(risk)
3739

3840
@property
3941
def errors(self) -> Set[str]:
@@ -60,6 +62,10 @@ def dynamic_tables(self) -> Set[str]:
6062
def dialect(self) -> str:
6163
return self._dialect
6264

65+
@property
66+
def risk(self) -> float:
67+
return sum(self._risk) / len(self._risk) if len(self._risk) > 0 else 0
68+
6369

6470
def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
6571
"""
@@ -75,19 +81,23 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
7581
- "allowed" (bool): Whether the query is allowed to run.
7682
- "errors" (List[str]): List of errors found during verification.
7783
- "fixed" (Optional[str]): The fixed query if modifications were made.
84+
- "risk" (float): Verification risk score (0 - no risk, 1 - high risk)
7885
"""
86+
result = _VerificationContext(config, dialect)
7987
try:
8088
parsed = sqlglot.parse_one(sql, dialect=dialect)
8189
except sqlglot.errors.ParseError as e:
8290
logging.error(f"SQL: {sql}\nError parsing SQL: {e}")
83-
return { "allowed": False, "errors": e.errors}
84-
if not isinstance(parsed, expr.Select):
85-
return {"allowed": False, "errors": ["Could not find a select statement"]}
86-
result = _VerificationContext(config, dialect)
87-
_verify_select_statement(parsed, result)
91+
result.add_error(f"Error parsing sql: {e}", False, 0.9)
92+
parsed = None
93+
if parsed:
94+
if not isinstance(parsed, expr.Select):
95+
result.add_error("Could not find a select statement", False, 0.7)
96+
else:
97+
_verify_select_statement(parsed, result)
8898
if result.can_fix and len(result.errors) > 0:
8999
result.fixed = parsed.sql()
90-
return { "allowed": len(result.errors) == 0, "errors": result.errors, "fixed": result.fixed }
100+
return { "allowed": len(result.errors) == 0, "errors": result.errors, "fixed": result.fixed, "risk": result.risk}
91101

92102

93103
def _verify_where_clause(result: _VerificationContext, select_statement: expr.Select,
@@ -108,7 +118,9 @@ def _verify_where_clause(result: _VerificationContext, select_statement: expr.Se
108118
found = True
109119
break
110120
if not found:
111-
result.add_error(f"Missing restriction for table: {t['table_name']} column: {r['column']} value: {r['value']}")
121+
result.add_error(
122+
f"Missing restriction for table: {t['table_name']} column: {r['column']} value: {r['value']}",
123+
True, 0.5)
112124
value = f"'{r['value']}'" if isinstance(r["value"], str) else r["value"]
113125
new_condition = sqlglot.parse_one(f"{r['column']} = {value}", dialect=result.dialect)
114126
if where_clause is None:
@@ -135,7 +147,7 @@ def _has_static_expression(context: _VerificationContext, exp: expr.Expression)
135147
result = _has_static_expression(context, sub_exp)
136148
elif not sub_exp.find(expr.Column):
137149
context.add_error(
138-
f"Static expression is not allowed: {sub_exp.sql()}", False)
150+
f"Static expression is not allowed: {sub_exp.sql()}", False, 0.9)
139151
result = True
140152
return result
141153

@@ -181,7 +193,7 @@ def _verify_select_statement(select_statement: expr.Select,
181193
if t.name == config_t["table_name"] or t.name in context.dynamic_tables:
182194
found = True
183195
if not found:
184-
context.add_error(f"Table {t.name} is not allowed", False)
196+
context.add_error(f"Table {t.name} is not allowed", False, 1)
185197
if not context.can_fix:
186198
return select_statement
187199
_verify_select_clause(context, select_statement, from_tables)
@@ -199,15 +211,15 @@ def _verify_select_clause(context: _VerificationContext,
199211
for e in to_remove:
200212
select_clause.expressions.remove(e)
201213
if len(select_clause.expressions) == 0:
202-
context.add_error("No legal elements in SELECT clause", False)
214+
context.add_error("No legal elements in SELECT clause", False, 0.5)
203215

204216
def _verify_select_clause_element(from_tables: List[expr.Table], context: _VerificationContext,
205217
e: expr.Expression):
206218
if isinstance(e, expr.Column):
207219
if not _verify_col(e, from_tables, context):
208220
return False
209221
elif isinstance(e, expr.Star):
210-
context.add_error("SELECT * is not allowed", True)
222+
context.add_error("SELECT * is not allowed", True, 0.1)
211223
for t in from_tables:
212224
for config_t in context.config["tables"]:
213225
if t.name == config_t["table_name"]:
@@ -241,7 +253,8 @@ def _verify_col(col: expr.Column, from_tables: List[expr.Table], context: _Verif
241253
if col.table == "sub_select" or col.table != "" and col.table in context.dynamic_tables:
242254
pass
243255
elif not _find_column(col.name, from_tables, context):
244-
context.add_error(f"Column {col.name} is not allowed. Column removed from SELECT clause")
256+
context.add_error(f"Column {col.name} is not allowed. Column removed from SELECT clause",
257+
True,0.3)
245258
return False
246259
return True
247260

test/test_sql_guard_unit.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,10 @@ def _test_sql(sql: str, config: dict, errors: Set[str] = None, fix: str = None,
1717
assert result["errors"] == set()
1818
else:
1919
assert set(result["errors"]) == set(errors)
20+
if len(result["errors"]) > 0:
21+
assert result["risk"] > 0
22+
else:
23+
assert result["risk"] == 0
2024
if fix is None:
2125
assert result.get("fixed") is None
2226
sql_to_use = sql
@@ -41,7 +45,9 @@ class TestSQLErrors:
4145
def test_basic_sql_error(self):
4246
result = verify_sql("this is not an sql statement ",{})
4347
assert result["allowed"] == False
44-
assert "Invalid expression / Unexpected token" in result["errors"][0]["description"]
48+
assert len(result["errors"]) == 1
49+
error = next(iter(result["errors"]))
50+
assert "Invalid expression / Unexpected token" in error
4551

4652

4753
class TestSingleTable:
@@ -97,6 +103,11 @@ def test_by_name(self, test_name, config, cnn, tests):
97103
_test_sql(test["sql"], config, set(test.get("errors", [])),
98104
test.get("fix"), cnn=cnn, data=test.get("data"))
99105

106+
def test_risk(self, config):
107+
result = verify_sql("SELECT * FROM orders", config)
108+
assert result["risk"] > 0
109+
110+
100111

101112
class TestJoinTable:
102113

0 commit comments

Comments
 (0)