Skip to content

Commit 9b4f968

Browse files
committed
add union all support + more tests
1 parent 2392299 commit 9b4f968

File tree

2 files changed

+35
-19
lines changed

2 files changed

+35
-19
lines changed

src/sql_data_guard/sql_data_guard.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,20 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
9191
result.add_error(f"Error parsing sql: {e}", False, 0.9)
9292
parsed = None
9393
if parsed:
94-
if not isinstance(parsed, expr.Select):
95-
result.add_error("Could not find a select statement", False, 0.7)
94+
if isinstance(parsed, expr.Command):
95+
result.add_error(f"{parsed.name} statement is not allowed", False, 0.9)
96+
elif isinstance(parsed, expr.Delete):
97+
result.add_error(f"{parsed.key.upper()} statement is not allowed", False, 0.9)
98+
elif isinstance(parsed, expr.Query):
99+
_verify_query_statement(parsed, result)
96100
else:
97-
_verify_select_statement(parsed, result)
101+
result.add_error("Could not find a query statement", False, 0.7)
98102
if result.can_fix and len(result.errors) > 0:
99103
result.fixed = parsed.sql()
100104
return { "allowed": len(result.errors) == 0, "errors": result.errors, "fixed": result.fixed, "risk": result.risk}
101105

102106

103-
def _verify_where_clause(result: _VerificationContext, select_statement: expr.Select,
107+
def _verify_where_clause(result: _VerificationContext, select_statement: expr.Query,
104108
from_tables: List[expr.Table]):
105109
where_clause = select_statement.find(expr.Where)
106110
if where_clause is None:
@@ -181,12 +185,16 @@ def _verify_restriction(restriction: dict, exp: expr.Expression) -> bool:
181185
return exp.right.this == str(restriction["value"])
182186
return False
183187

184-
def _verify_select_statement(select_statement: expr.Select,
185-
context: _VerificationContext):
186-
for cte in select_statement.ctes:
188+
def _verify_query_statement(query_statement: expr.Query,
189+
context: _VerificationContext):
190+
if isinstance(query_statement, expr.Union):
191+
_verify_query_statement(query_statement.left, context)
192+
_verify_query_statement(query_statement.right, context)
193+
return
194+
for cte in query_statement.ctes:
187195
context.dynamic_tables.add(cte.alias)
188-
_verify_select_statement(cte.this, context)
189-
from_tables = _get_from_clause_tables(select_statement, context)
196+
_verify_query_statement(cte.this, context)
197+
from_tables = _get_from_clause_tables(query_statement, context)
190198
for t in from_tables:
191199
found = False
192200
for config_t in context.config["tables"]:
@@ -195,14 +203,14 @@ def _verify_select_statement(select_statement: expr.Select,
195203
if not found:
196204
context.add_error(f"Table {t.name} is not allowed", False, 1)
197205
if not context.can_fix:
198-
return select_statement
199-
_verify_select_clause(context, select_statement, from_tables)
200-
_verify_where_clause(context, select_statement, from_tables)
201-
return select_statement
206+
return query_statement
207+
_verify_select_clause(context, query_statement, from_tables)
208+
_verify_where_clause(context, query_statement, from_tables)
209+
return query_statement
202210

203211

204212
def _verify_select_clause(context: _VerificationContext,
205-
select_clause: expr.Select,
213+
select_clause: expr.Query,
206214
from_tables: List[expr.Table]):
207215
to_remove = []
208216
for e in select_clause.expressions:
@@ -281,7 +289,7 @@ def _find_column(col_name: str, from_tables: List[expr.Table], result: _Verifica
281289
return False
282290

283291

284-
def _get_from_clause_tables(select_clause: expr.Select, context: _VerificationContext) -> List[expr.Table]:
292+
def _get_from_clause_tables(select_clause: expr.Query, context: _VerificationContext) -> List[expr.Table]:
285293
"""
286294
Extracts table references from the FROM clause of an SQL query.
287295
@@ -305,12 +313,12 @@ def _get_from_clause_tables(select_clause: expr.Select, context: _VerificationCo
305313
for j in _find_direct(clause, expr.Subquery):
306314
if j.alias != "":
307315
context.dynamic_tables.add(j.alias)
308-
_verify_select_statement(j.this, context)
316+
_verify_query_statement(j.this, context)
309317
if join_clause:
310318
for j in _find_direct(clause, expr.Lateral):
311319
if j.alias != "":
312320
context.dynamic_tables.add(j.alias)
313-
_verify_select_statement(j.this.find(expr.Select), context)
321+
_verify_query_statement(j.this.find(expr.Select), context)
314322
return result
315323

316324

test/resources/orders_test.jsonl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
{"name": "two_illegal_tables", "sql": "SELECT col1 FROM users AS u1 JOIN products AS p1", "errors": ["Table users is not allowed", "Table products is not allowed"]}
33
{"name": "select_no_legal_cols", "sql": "SELECT col1, col2 FROM orders WHERE id = 123", "errors": ["Column col1 is not allowed. Column removed from SELECT clause", "Column col2 is not allowed. Column removed from SELECT clause", "No legal elements in SELECT clause"]}
44
{"name": "select_star", "sql": "SELECT * FROM orders WHERE id = 123", "errors": ["SELECT * is not allowed"], "fix": "SELECT id, product_name, account_id, day FROM orders WHERE id = 123", "data": [[123, "product1", 123, "2025-01-01"]]}
5+
{"name": "select_star_with_column", "sql": "SELECT product_name, * FROM orders WHERE id = 123", "errors": ["SELECT * is not allowed"], "fix": "SELECT product_name, id, product_name, account_id, day FROM orders WHERE id = 123", "data": [["product1", 123, "product1", 123, "2025-01-01"]]}
6+
{"name": "select_star_with_column_and_alias", "sql": "SELECT product_name AS \"p_n\", * FROM orders WHERE id = 123", "errors": ["SELECT * is not allowed"], "fix": "SELECT product_name AS \"p_n\", id, product_name, account_id, day FROM orders WHERE id = 123", "data": [["product1", 123, "product1", 123, "2025-01-01"]]}
57
{"name": "two_cols", "sql": "SELECT id, product_name FROM orders WHERE id = 123", "errors": [], "data": [[123, "product1"]]}
68
{"name": "quote_and_alias", "sql": "SELECT \"id\" AS my_id FROM orders WHERE id = 123", "errors": [], "data": [[123]]}
79
{"name": "sql_with_group_by_and_order_by", "sql": "SELECT id FROM orders GROUP BY id ORDER BY id", "errors": ["Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE id = 123 GROUP BY id ORDER BY id", "data": [[123]]}
@@ -50,5 +52,11 @@
5052
{"name": "is_null", "sql": "SELECT id FROM orders WHERE day IS NOT NULL AND id = 123", "errors": [], "data": [[123]]}
5153
{"name": "is_null_static_exp", "sql": "SELECT id FROM orders WHERE NULL IS NULL AND id = 123", "errors": ["Static expression is not allowed: NULL IS NULL"], "data": [[123]]}
5254
{"name": "not_op", "sql": "SELECT id FROM orders WHERE NOT id = 123", "errors": ["Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE (NOT id = 123) AND id = 123", "data": []}
53-
{"name": "delete_op", "sql": "DELETE FROM orders", "errors": ["Could not find a select statement"]}
54-
{"name": "drop_op", "sql": "DROP orders", "errors": ["Could not find a select statement"]}
55+
{"name": "delete_op", "sql": "DELETE FROM orders", "errors": ["DELETE statement is not allowed"]}
56+
{"name": "drop_op", "sql": "DROP orders", "errors": ["DROP statement is not allowed"]}
57+
{"name": "json_object", "sql": "SELECT json_object('id', id) FROM orders WHERE id = 123", "data": [["{\"id\":123}"]]}
58+
{"name": "json_object_with_illegal_col", "sql": "SELECT json_object('id', id, 'status', status) FROM orders WHERE id = 123", "errors": ["Column status is not allowed. Column removed from SELECT clause", "No legal elements in SELECT clause"]}
59+
{"name": "json_object_with_illegal_col_fix", "sql": "SELECT id, json_object('id', id, 'status', status) FROM orders WHERE id = 123", "errors": ["Column status is not allowed. Column removed from SELECT clause"], "fix": "SELECT id FROM orders WHERE id = 123", "data": [[123]]}
60+
{"name": "union_all", "sql": "SELECT id FROM orders WHERE id = 123 UNION ALL SELECT id FROM orders WHERE id = 123", "errors": [], "data": [[123], [123]]}
61+
{"name": "union_all_3_parts", "sql": "SELECT id FROM orders WHERE id = 123 UNION ALL SELECT id FROM orders WHERE id = 123 UNION ALL SELECT id FROM orders WHERE id = 123", "errors": [], "data": [[123], [123], [123]]}
62+
{"name": "union_all_missing_restriction", "sql": "SELECT id FROM orders WHERE id = 123 UNION ALL SELECT id FROM orders", "errors": ["Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE id = 123 UNION ALL SELECT id FROM orders WHERE id = 123", "data": [[123], [123]]}

0 commit comments

Comments
 (0)