Skip to content

Commit cdf9934

Browse files
committed
remove static expressions
1 parent 95d43a5 commit cdf9934

File tree

3 files changed

+41
-20
lines changed

3 files changed

+41
-20
lines changed

src/sql_data_guard/sql_data_guard.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import sqlglot
55
import sqlglot.expressions as expr
6+
from sqlglot.optimizer.simplify import simplify
67

78

89
class _VerificationContext:
@@ -104,29 +105,33 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
104105
return { "allowed": len(result.errors) == 0, "errors": result.errors, "fixed": result.fixed, "risk": result.risk}
105106

106107

107-
def _verify_where_clause(result: _VerificationContext, select_statement: expr.Query,
108+
def _verify_where_clause(context: _VerificationContext, select_statement: expr.Query,
109+
from_tables: List[expr.Table]):
110+
_verify_static_expression(select_statement, context)
111+
_verify_restrictions(select_statement, context, from_tables)
112+
113+
def _verify_restrictions(select_statement: expr.Query,
114+
context: _VerificationContext,
108115
from_tables: List[expr.Table]):
109116
where_clause = select_statement.find(expr.Where)
110117
if where_clause is None:
111118
where_clause = select_statement.find(expr.Where)
112119
and_exps = []
113120
else:
114121
and_exps = list(_split_to_expressions(where_clause.this, expr.And))
115-
if not _verify_static_expression(result, and_exps):
116-
return
117-
for t in [c_t for c_t in result.config["tables"] if c_t["table_name"] in [t.name for t in from_tables]]:
122+
for t in [c_t for c_t in context.config["tables"] if c_t["table_name"] in [t.name for t in from_tables]]:
118123
for idx, r in enumerate(t.get("restrictions", [])):
119124
found = False
120125
for sub_exp in and_exps:
121126
if _verify_restriction(r, sub_exp):
122127
found = True
123128
break
124129
if not found:
125-
result.add_error(
130+
context.add_error(
126131
f"Missing restriction for table: {t['table_name']} column: {r['column']} value: {r['value']}",
127132
True, 0.5)
128133
value = f"'{r['value']}'" if isinstance(r["value"], str) else r["value"]
129-
new_condition = sqlglot.parse_one(f"{r['column']} = {value}", dialect=result.dialect)
134+
new_condition = sqlglot.parse_one(f"{r['column']} = {value}", dialect=context.dialect)
130135
if where_clause is None:
131136
where_clause = expr.Where(this=new_condition)
132137
select_statement.set("where", where_clause)
@@ -135,24 +140,37 @@ def _verify_where_clause(result: _VerificationContext, select_statement: expr.Qu
135140
expression=new_condition)))
136141

137142

138-
def _verify_static_expression(context: _VerificationContext, exps: List[expr.Expression]) -> bool:
143+
def _verify_static_expression(select_statement: expr.Query, context: _VerificationContext) -> bool:
139144
has_static_exp = False
140-
for e in exps:
141-
if _has_static_expression(context, e):
142-
has_static_exp = True
145+
where_clause = select_statement.find(expr.Where)
146+
if where_clause:
147+
and_exps = list(_split_to_expressions(where_clause.this, expr.And))
148+
for e in and_exps:
149+
if _has_static_expression(context, e):
150+
has_static_exp = True
151+
if has_static_exp:
152+
simplify(where_clause)
143153
return not has_static_exp
144154

145155
def _has_static_expression(context: _VerificationContext, exp: expr.Expression) -> bool:
146156
if isinstance(exp, expr.Not):
147157
return _has_static_expression(context, exp.this)
148158
result = False
159+
to_replace = []
149160
for sub_exp in _split_to_expressions(exp, expr.Or):
150161
if isinstance(sub_exp, (expr.Or, expr.And)):
151162
result = _has_static_expression(context, sub_exp)
152163
elif not sub_exp.find(expr.Column):
153164
context.add_error(
154-
f"Static expression is not allowed: {sub_exp.sql()}", False, 0.9)
165+
f"Static expression is not allowed: {sub_exp.sql()}", True, 0.8)
166+
par = sub_exp.parent
167+
while isinstance(par, expr.Paren):
168+
par = par.parent
169+
if isinstance(par, expr.Or):
170+
to_replace.append(sub_exp)
155171
result = True
172+
for e in to_replace:
173+
e.replace(expr.Boolean(this=False))
156174
return result
157175

158176

@@ -171,7 +189,6 @@ def _verify_restriction(restriction: dict, exp: expr.Expression) -> bool:
171189
return False
172190
if isinstance(exp, expr.Paren):
173191
return _verify_restriction(restriction, exp.this)
174-
175192
if not isinstance(exp.this, expr.Column):
176193
return False
177194
if not exp.this.name == restriction["column"]:

test/resources/orders_test.jsonl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@
2222
{"name": "bad_restriction", "sql": "SELECT id FROM orders WHERE id = 123 OR id = 234", "errors": ["Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE (id = 123 OR id = 234) AND id = 123"}
2323
{"name": "bracketed", "sql": "SELECT id FROM orders WHERE (id = 123)", "errors": [], "data": [[123]]}
2424
{"name": "double_bracketed", "sql": "SELECT id FROM orders WHERE ((id = 123))", "errors": [], "data": [[123]]}
25-
{"name": "static_exp", "sql": "SELECT id FROM orders WHERE id = 123 OR (1 = 1)", "errors": ["Static expression is not allowed: 1 = 1"]}
26-
{"name": "two_static_exps", "sql": "SELECT id FROM orders WHERE id = 123 OR (1 = 1) OR (2 = 2)", "errors": ["Static expression is not allowed: 1 = 1", "Static expression is not allowed: 2 = 2"]}
27-
{"name": "nested_static_exp", "sql": "SELECT id FROM orders WHERE id = 123 OR (id = 1 OR TRUE)", "errors": ["Static expression is not allowed: TRUE"]}
28-
{"name": "nested_static_exp2", "sql": "SELECT id FROM orders WHERE id = 123 AND (product_name = 'product1' OR (TRUE))", "errors": ["Static expression is not allowed: TRUE"]}
25+
{"name": "static_exp", "sql": "SELECT id FROM orders WHERE id = 123 OR 1 = 1", "errors": ["Static expression is not allowed: 1 = 1"], "fix": "SELECT id FROM orders WHERE id = 123", "data": [[123]]}
26+
{"name": "only_static_exp", "sql": "SELECT id FROM orders WHERE 1 = 1", "errors": ["Static expression is not allowed: 1 = 1", "Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE id = 123", "data": [[123]]}
27+
{"name": "only_static_exp_false", "sql": "SELECT id FROM orders WHERE 1 = 0", "errors": ["Static expression is not allowed: 1 = 0", "Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE (FALSE) AND id = 123", "data": []}
28+
{"name": "static_exp_paren", "sql": "SELECT id FROM orders WHERE id = 123 OR (1 = 1)", "errors": ["Static expression is not allowed: 1 = 1"], "fix": "SELECT id FROM orders WHERE id = 123", "data": [[123]]}
29+
{"name": "two_static_exps", "sql": "SELECT id FROM orders WHERE id = 123 OR (1 = 1) OR (2 = 2)", "errors": ["Static expression is not allowed: 1 = 1", "Static expression is not allowed: 2 = 2"], "fix": "SELECT id FROM orders WHERE id = 123", "data": [[123]]}
30+
{"name": "static_exp_with_missing_restriction", "sql": "SELECT id, name FROM orders WHERE 1 = 1", "errors": ["Column name is not allowed. Column removed from SELECT clause", "Static expression is not allowed: 1 = 1", "Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE id = 123", "data": [[123]]}
31+
{"name": "nested_static_exp", "sql": "SELECT id FROM orders WHERE id = 123 OR (id = 1 OR TRUE)", "errors": ["Static expression is not allowed: TRUE", "Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE (id = 1 OR id = 123) AND id = 123", "data": [[123]]}
32+
{"name": "nested_static_exp2", "sql": "SELECT id FROM orders WHERE id = 123 AND (product_name = 'product1' OR (TRUE))", "errors": ["Static expression is not allowed: TRUE"], "fix": "SELECT id FROM orders WHERE id = 123 AND product_name = 'product1'", "data": [[123]]}
2933
{"name": "multiple_brackets_exp", "sql": "SELECT id FROM orders WHERE (( ( (id = 123))))", "errors": [], "data": [[123]]}
3034
{"name": "with_clause", "sql": "WITH data AS (SELECT id FROM orders WHERE id = 123) SELECT id FROM data", "errors": [], "data": [[123]]}
3135
{"name": "nested_with_clause", "sql": "WITH data AS (WITH sub_data AS (SELECT id FROM orders) SELECT id FROM sub_data) SELECT id FROM data", "errors": ["Missing restriction for table: orders column: id value: 123"], "fix": "WITH data AS (WITH sub_data AS (SELECT id FROM orders WHERE id = 123) SELECT id FROM sub_data) SELECT id FROM data"}
@@ -48,10 +52,10 @@
4852
{"name": "no_from_sub_select", "sql": "SELECT id, sub.col FROM orders CROSS JOIN (SELECT 11 AS col) AS sub WHERE id = 123", "errors": [], "data": [[123, 11]]}
4953
{"name": "no_from_sub_select_lateral", "sql": "SELECT id, sub.col FROM orders CROSS JOIN LATERAL (SELECT 11 AS col) AS sub WHERE id = 123", "errors": []}
5054
{"name": "day_between", "sql": "SELECT id FROM orders WHERE DATE(day) BETWEEN DATE('2000-01-01') AND DATE('now','-1 day') AND id = 123", "errors": [], "data": [[123]]}
51-
{"name": "day_between_static_exp", "sql": "SELECT id FROM orders WHERE DATE('2000-01-01') BETWEEN DATE('2000-01-01') AND DATE('2000-01-01') AND id = 123", "errors": ["Static expression is not allowed: DATE('2000-01-01') BETWEEN DATE('2000-01-01') AND DATE('2000-01-01')"], ",data": [[123]]}
55+
{"name": "day_between_static_exp", "sql": "SELECT id FROM orders WHERE DATE('2000-01-01') BETWEEN DATE('2000-01-01') AND DATE('2000-01-01') OR id = 123", "errors": ["Static expression is not allowed: DATE('2000-01-01') BETWEEN DATE('2000-01-01') AND DATE('2000-01-01')"], "fix": "SELECT id FROM orders WHERE id = 123" ,"data": [[123]]}
5256
{"name": "day_in_func", "sql": "SELECT id FROM orders WHERE LOWER(LOWER(LOWER(day))) <> '' AND id = 123", "errors": [], "data": [[123]]}
53-
{"name": "is_null", "sql": "SELECT id FROM orders WHERE day IS NOT NULL AND id = 123", "errors": [], ",data": [[123]]}
54-
{"name": "is_null_static_exp", "sql": "SELECT id FROM orders WHERE NULL IS NULL AND id = 123", "errors": ["Static expression is not allowed: NULL IS NULL"], ",data": [[123]]}
57+
{"name": "is_null", "sql": "SELECT id FROM orders WHERE day IS NOT NULL AND id = 123", "errors": []}
58+
{"name": "is_null_static_exp", "sql": "SELECT id FROM orders WHERE NULL IS NULL AND id = 123", "errors": ["Static expression is not allowed: NULL IS NULL"], "fix": "SELECT id FROM orders WHERE id = 123", "data": [[123]]}
5559
{"name": "not_op", "sql": "SELECT id FROM orders WHERE NOT id = 123", "errors": ["Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE (NOT id = 123) AND id = 123", "data": []}
5660
{"name": "delete_op", "sql": "DELETE FROM orders", "errors": ["DELETE statement is not allowed"]}
5761
{"name": "drop_op", "sql": "DROP orders", "errors": ["DROP statement is not allowed"]}

test/test_sql_guard_unit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_orders_from_file_ai(self, test_name, config, cnn, ai_tests):
9696
_test_sql(test["sql"], config, set(test.get("errors", [])),
9797
test.get("fix"), cnn=cnn, data=test.get("data"))
9898

99-
@pytest.mark.parametrize("test_name", ["no_from_sub_select_lateral"])
99+
@pytest.mark.parametrize("test_name", ["day_between_static_exp"])
100100
def test_by_name(self, test_name, config, cnn, tests):
101101
"""Test by name. Use it to run a single test from tests/ai_tests by name"""
102102
test = tests[test_name]

0 commit comments

Comments
 (0)