Skip to content

Commit 8891580

Browse files
Added array_col, cross join, unnest tests
1 parent f890117 commit 8891580

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

test/test_sql_guard_modified_unit.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,32 @@ def cnn(self):
270270
"INSERT INTO products1 VALUES (445, 'prod3', 'shipped', 'granted', '28-02-2025', 'c3')"
271271
)
272272

273+
274+
# Trying to do array_col
275+
conn.execute(
276+
"""
277+
CREATE TABLE orders_db.products2 (
278+
id INT,
279+
prod_name TEXT,
280+
deliver TEXT,
281+
access TEXT,
282+
date TEXT,
283+
cust_id TEXT,
284+
category TEXT -- JSON formatted array column
285+
)"""
286+
)
287+
288+
# Insert values into products1 table (JSON formatted array)
289+
conn.execute(
290+
"INSERT INTO products2 VALUES (324, 'prod1', 'delivered', 'granted', '27-02-2025', 'c1', '[""electronics"", ""fashion""]')"
291+
)
292+
conn.execute(
293+
"INSERT INTO products2 VALUES (435, 'prod2', 'delayed', 'pending', '02-03-2025', 'c2', '[""books""]')"
294+
)
295+
conn.execute(
296+
"INSERT INTO products2 VALUES (445, 'prod3', 'shipped', 'granted', '28-02-2025', 'c3', '[""sports"", ""toys""]')"
297+
)
298+
273299
# Creating customers table
274300
conn.execute(
275301
"""
@@ -297,12 +323,23 @@ def config(self) -> dict:
297323
"columns": ["id", "prod_name", "category"],
298324
"restrictions": [{"column": "id", "value": 324}],
299325
},
326+
{
327+
"table_name": "products2",
328+
"database_name": "orders_db",
329+
"columns": ["id", "prod_name", "category"], # category stored as JSON
330+
"restrictions": [{"column": "id", "value": 324}],
331+
},
300332
{
301333
"table_name": "customers",
302334
"database_name": "orders_db",
303335
"columns": ["cust_id", "cust_name", "access"],
304336
"restrictions": [{"column": "access", "value": "restricted"}],
305337
},
338+
{
339+
"table_name": "highlights",
340+
"database_name": "countdb",
341+
"columns": ["vals", "anomalies","id"],
342+
}
306343
]
307344
}
308345

@@ -447,3 +484,57 @@ def test_distinct_and_group_by_missing_restriction(self, config, cnn):
447484
cnn=cnn,
448485
data=[(1, "prod1")],
449486
)
487+
488+
def test_array_col(self,config,cnn):
489+
sql = """
490+
SELECT id, prod_name FROM products2
491+
WHERE (category LIKE '%electronics%') AND id = 324
492+
"""
493+
res = verify_sql(sql, config)
494+
assert res["allowed"] == True, res
495+
print(cnn.execute(sql).fetchall())
496+
assert cnn.execute(sql).fetchall() == [(324, "prod1")]
497+
498+
def test_cross_join_alias(self,config,cnn):
499+
sql = """SELECT p1.id, p2.id FROM products1 AS p1
500+
CROSS JOIN products1 AS p2 WHERE p1.id = 324 AND p2.id = 324"""
501+
res = verify_sql(sql,config)
502+
assert res["allowed"] == True, res
503+
print(cnn.execute(sql).fetchall())
504+
505+
def test_self_join(self,config,cnn):
506+
sql = """SELECT p1.id, p2.id from products1 as p1
507+
inner join products1 as p2 on p1.id = p2.id WHERE p1.id = 324 and p2.id = 324"""
508+
res = verify_sql(sql, config)
509+
assert res["allowed"] == True, res
510+
print(cnn.execute(sql).fetchall())
511+
512+
def test_customers_restriction(self,config):
513+
sql = "SELECT cust_id, cust_name FROM customers WHERE (cust_id = 'c1') AND access = 'restricted'"
514+
res = verify_sql(sql, config)
515+
assert res["allowed"] == True, res
516+
517+
def test_json_field_products1(self,config,cnn):
518+
sql = "SELECT json_extract(category, '$[0]') FROM products2 WHERE id = 324"
519+
res = verify_sql(sql, config)
520+
assert res["allowed"] == True, res
521+
522+
def test_unnest_using_trino_array_val_cross_join(self,config):
523+
verify_sql_test('''SELECT val FROM (VALUES (ARRAY[1, 2, 3]))
524+
AS highlights(vals) CROSS JOIN UNNEST(vals) AS t(val)''',
525+
config,dialect="trino")
526+
527+
def test_unnest_using_trino_insert(self,config):
528+
verify_sql_test("INSERT INTO highlights VALUES (1, ARRAY[10, 20, 30])",config,
529+
dialect="trino", errors={'INSERT statement is not allowed'})
530+
531+
def test_unnest_using_trino_cross_join(self,config):
532+
verify_sql_test("SELECT t.val FROM highlights CROSS JOIN UNNEST(vals) AS t(val)",config
533+
,dialect="trino")
534+
535+
def test_unnest_using_trino_multi_col_alias(self,config):
536+
verify_sql_test("SELECT t.val, h.id FROM highlights AS h CROSS JOIN UNNEST(h.vals) AS t(val)",
537+
config,dialect="trino")
538+
539+
def test_unnest_using_trino_no_alias(self,config):
540+
verify_sql_test("SELECT anomalies from highlights CROSS JOIN UNNEST(vals)",config,dialect="trino")

0 commit comments

Comments
 (0)