Skip to content

Commit 199ab5d

Browse files
committed
fix sub queries in where and multiple joins
1 parent c85c0ec commit 199ab5d

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

src/sql_data_guard/sql_data_guard.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def _verify_query_statement(query_statement: expr.Query, context: VerificationCo
135135
for cte in query_statement.ctes:
136136
_add_table_alias(cte, context)
137137
_verify_query_statement(cte.this, context)
138+
where_clause = query_statement.find(expr.Where)
139+
if where_clause:
140+
for sub in where_clause.find_all(expr.Subquery):
141+
_verify_query_statement(sub.this, context)
138142
from_tables = _get_from_clause_tables(query_statement, context)
139143
for t in from_tables:
140144
found = False
@@ -251,7 +255,7 @@ def _get_from_clause_tables(
251255
"""
252256
result = []
253257
from_clause = select_clause.find(expr.From)
254-
join_clauses = list(select_clause.find_all(expr.Join))
258+
join_clauses = select_clause.args.get("joins", [])
255259
for clause in [from_clause] + join_clauses:
256260
if clause:
257261
for t in find_direct(clause, expr.Table):

test/resources/orders_test.jsonl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,6 @@
7777
{"name": "inner_select_clause_restricted_table2", "sql": "SELECT (SELECT id FROM users LIMIT 1) AS id FROM orders WHERE id = 123", "errors": ["Table users is not allowed"]}
7878
{"name": "inner_select_clause_restricted_table3", "sql": "SELECT (SELECT id FROM users LIMIT 1) FROM orders WHERE id = 123", "errors": ["Table users is not allowed"]}
7979
{"name": "inner_select_clause_restricted_col1", "sql": "SELECT (SELECT col1, id FROM orders WHERE id = 123) FROM orders WHERE id = 123", "errors": ["Column col1 is not allowed. Column removed from SELECT clause"], "fix": "SELECT (SELECT id FROM orders WHERE id = 123) FROM orders WHERE id = 123", "data": [[123]]}
80-
{"name": "multiple_joins1", "sql": "SELECT id FROM orders AS o1 JOIN orders AS o2 JOIN users WHERE o1.id=123 AND o2.id=123", "errors": ["Table users is not allowed"]}
80+
{"name": "multiple_joins1", "sql": "SELECT id FROM orders AS o1 JOIN orders AS o2 JOIN users WHERE o1.id=123 AND o2.id=123", "errors": ["Table users is not allowed"]}
81+
{"name": "sub_query_in_where", "sql": "SELECT id FROM orders WHERE id = 123 AND id IN (SELECT id FROM orders WHERE id = 123)", "data": [[123]]}
82+
{"name": "sub_query_in_where_missing_restriction", "sql": "SELECT id FROM orders WHERE id = 123 AND id IN (SELECT id FROM orders)", "errors": ["Missing restriction for table: orders column: id value: 123"], "fix": "SELECT id FROM orders WHERE id = 123 AND id IN (SELECT id FROM orders WHERE id = 123)", "data": [[123]]}

test/test_sql_guard_unit.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def test_complex_join(self, config, cnn):
252252
COUNT(o.order_id) AS order_count
253253
FROM orders o
254254
JOIN products p ON o.product_id = p.product_id
255+
WHERE o.account_id = 123
255256
GROUP BY o.account_id, p.product_name
256257
),
257258
RankedProducts AS (
@@ -273,11 +274,16 @@ def test_complex_join(self, config, cnn):
273274
JOIN RankedProducts rp ON oc.product_name = rp.product_name
274275
WHERE oc.account_id IN (
275276
-- Filter accounts with at least 2 orders
276-
SELECT account_id FROM orders GROUP BY account_id HAVING COUNT(order_id) >= 2
277+
SELECT account_id FROM orders
278+
WHERE account_id = 123
279+
GROUP BY account_id HAVING COUNT(order_id) >= 2
277280
)
278281
ORDER BY oc.account_id, rp.product_rank;"""
279-
verify_sql_test_data(
280-
sql, config, cnn, [(123, "Laptop", 1, 1), (123, "Smartphone", 1, 1)]
282+
verify_sql_test(
283+
sql,
284+
config,
285+
cnn=cnn,
286+
data=[(123, "Laptop", 1, 1), (123, "Smartphone", 1, 1)],
281287
)
282288

283289

0 commit comments

Comments
 (0)