@@ -270,6 +270,32 @@ def cnn(self):
270270 "INSERT INTO products1 VALUES (445, 'prod3', 'shipped', 'granted', '28-02-2025', 'c3')"
271271 )
272272
273+
274+ # Trying to do array_col
275+ conn .execute (
276+ """
277+ CREATE TABLE orders_db.products2 (
278+ id INT,
279+ prod_name TEXT,
280+ deliver TEXT,
281+ access TEXT,
282+ date TEXT,
283+ cust_id TEXT,
284+ category TEXT -- JSON formatted array column
285+ )"""
286+ )
287+
288+ # Insert values into products1 table (JSON formatted array)
289+ conn .execute (
290+ "INSERT INTO products2 VALUES (324, 'prod1', 'delivered', 'granted', '27-02-2025', 'c1', '[" "electronics" ", " "fashion" "]')"
291+ )
292+ conn .execute (
293+ "INSERT INTO products2 VALUES (435, 'prod2', 'delayed', 'pending', '02-03-2025', 'c2', '[" "books" "]')"
294+ )
295+ conn .execute (
296+ "INSERT INTO products2 VALUES (445, 'prod3', 'shipped', 'granted', '28-02-2025', 'c3', '[" "sports" ", " "toys" "]')"
297+ )
298+
273299 # Creating customers table
274300 conn .execute (
275301 """
@@ -297,12 +323,23 @@ def config(self) -> dict:
297323 "columns" : ["id" , "prod_name" , "category" ],
298324 "restrictions" : [{"column" : "id" , "value" : 324 }],
299325 },
326+ {
327+ "table_name" : "products2" ,
328+ "database_name" : "orders_db" ,
329+ "columns" : ["id" , "prod_name" , "category" ], # category stored as JSON
330+ "restrictions" : [{"column" : "id" , "value" : 324 }],
331+ },
300332 {
301333 "table_name" : "customers" ,
302334 "database_name" : "orders_db" ,
303335 "columns" : ["cust_id" , "cust_name" , "access" ],
304336 "restrictions" : [{"column" : "access" , "value" : "restricted" }],
305337 },
338+ {
339+ "table_name" : "highlights" ,
340+ "database_name" : "countdb" ,
341+ "columns" : ["vals" , "anomalies" ,"id" ],
342+ }
306343 ]
307344 }
308345
@@ -447,3 +484,57 @@ def test_distinct_and_group_by_missing_restriction(self, config, cnn):
447484 cnn = cnn ,
448485 data = [(1 , "prod1" )],
449486 )
487+
488+ def test_array_col (self ,config ,cnn ):
489+ sql = """
490+ SELECT id, prod_name FROM products2
491+ WHERE (category LIKE '%electronics%') AND id = 324
492+ """
493+ res = verify_sql (sql , config )
494+ assert res ["allowed" ] == True , res
495+ print (cnn .execute (sql ).fetchall ())
496+ assert cnn .execute (sql ).fetchall () == [(324 , "prod1" )]
497+
498+ def test_cross_join_alias (self ,config ,cnn ):
499+ sql = """SELECT p1.id, p2.id FROM products1 AS p1
500+ CROSS JOIN products1 AS p2 WHERE p1.id = 324 AND p2.id = 324"""
501+ res = verify_sql (sql ,config )
502+ assert res ["allowed" ] == True , res
503+ print (cnn .execute (sql ).fetchall ())
504+
505+ def test_self_join (self ,config ,cnn ):
506+ sql = """SELECT p1.id, p2.id from products1 as p1
507+ inner join products1 as p2 on p1.id = p2.id WHERE p1.id = 324 and p2.id = 324"""
508+ res = verify_sql (sql , config )
509+ assert res ["allowed" ] == True , res
510+ print (cnn .execute (sql ).fetchall ())
511+
512+ def test_customers_restriction (self ,config ):
513+ sql = "SELECT cust_id, cust_name FROM customers WHERE (cust_id = 'c1') AND access = 'restricted'"
514+ res = verify_sql (sql , config )
515+ assert res ["allowed" ] == True , res
516+
517+ def test_json_field_products1 (self ,config ,cnn ):
518+ sql = "SELECT json_extract(category, '$[0]') FROM products2 WHERE id = 324"
519+ res = verify_sql (sql , config )
520+ assert res ["allowed" ] == True , res
521+
522+ def test_unnest_using_trino_array_val_cross_join (self ,config ):
523+ verify_sql_test ('''SELECT val FROM (VALUES (ARRAY[1, 2, 3]))
524+ AS highlights(vals) CROSS JOIN UNNEST(vals) AS t(val)''' ,
525+ config ,dialect = "trino" )
526+
527+ def test_unnest_using_trino_insert (self ,config ):
528+ verify_sql_test ("INSERT INTO highlights VALUES (1, ARRAY[10, 20, 30])" ,config ,
529+ dialect = "trino" , errors = {'INSERT statement is not allowed' })
530+
531+ def test_unnest_using_trino_cross_join (self ,config ):
532+ verify_sql_test ("SELECT t.val FROM highlights CROSS JOIN UNNEST(vals) AS t(val)" ,config
533+ ,dialect = "trino" )
534+
535+ def test_unnest_using_trino_multi_col_alias (self ,config ):
536+ verify_sql_test ("SELECT t.val, h.id FROM highlights AS h CROSS JOIN UNNEST(h.vals) AS t(val)" ,
537+ config ,dialect = "trino" )
538+
539+ def test_unnest_using_trino_no_alias (self ,config ):
540+ verify_sql_test ("SELECT anomalies from highlights CROSS JOIN UNNEST(vals)" ,config ,dialect = "trino" )
0 commit comments