99
1010from sql_data_guard import verify_sql
1111
12-
1312def _test_sql (sql : str , config : dict , errors : Set [str ] = None , fix : str = None , dialect : str = "sqlite" ,
1413 cnn : Connection = None , data : list = None ):
1514 result = verify_sql (sql , config , dialect )
@@ -32,14 +31,6 @@ def _test_sql(sql: str, config: dict, errors: Set[str] = None, fix: str = None,
3231 if data is not None :
3332 assert fetched_data == [tuple (row ) for row in data ]
3433
35- def _get_resource (file_name : str ) -> str :
36- return os .path .join (os .path .dirname (os .path .abspath (__file__ )), file_name )
37-
38- def _get_tests (file_name : str ) -> Generator [dict , None , None ]:
39- with open (_get_resource (os .path .join ("resources" , file_name ))) as f :
40- for line in f :
41- yield json .loads (line )
42-
4334class TestInvalidQueries :
4435
4536 @pytest .fixture (scope = "class" )
@@ -110,26 +101,40 @@ def config(self) -> dict:
110101 }
111102
112103 def test_access_denied (self , config ):
113- result = verify_sql ("SELECT id, prod_name FROM products1" , config )
114- assert result ["allowed" ] == False , result
104+ result = verify_sql ('''SELECT id, prod_name FROM products1
105+ WHERE id = 324 AND access = 'granted' AND date = '27-02-2025'
106+ AND cust_id = 'c1' ''' , config )
107+ assert result ["allowed" ] == True , result # changed from select id, prod_name to this query
115108
116109 def test_restricted_access (self , config ):
117- result = verify_sql ("SELECT * FROM products1" , config )
118- assert result ["allowed" ] == False , result
110+ result = verify_sql ('''SELECT id, prod_name, deliver, access, date, cust_id
111+ FROM products1 WHERE access = 'granted'
112+ AND date = '27-02-2025' AND cust_id = 'c1' ''' , config ) # Changed from select * to this query
113+ assert result ["allowed" ] == True , result
119114
120115 def test_invalid_query1 (self , config ):
121- res = verify_sql ("SELECT I" , config )
122- assert res ["allowed" ] == True , res #False #needs to be False, but is not passing, new error message needs to be added
116+ res = verify_sql ("SELECT I from H" , config )
117+ assert not res ["allowed" ] # gives error only when invalid table is mentioned
118+ assert 'Table H is not allowed' in res ['errors' ]
123119
124120 def test_invalid_select (self , config ):
125- res = verify_sql ("SELECT id, prod_name, deliver from products1 where id = 324" , config )
126- assert res ['allowed' ] == False
127- print (res ["errors" ])
121+ res = verify_sql ('''SELECT id, prod_name, deliver FROM
122+ products1 WHERE id = 324 AND access = 'granted'
123+ AND date = '27-02-2025' AND cust_id = 'c1' ''' , config )
124+ assert res ['allowed' ] == True , res #changed from select id, prod_name, deliver from products1 where id = 324 to this
125+
126+ # checking error
127+ def test_invalid_select_error_check (self , config ):
128+ res = verify_sql ('''select id, prod_name, deliver from products1 where id = 324 ''' , config )
129+ assert not res ['allowed' ]
130+ assert 'Missing restriction for table: products1 column: access value: granted' in res ['errors' ]
131+ assert 'Missing restriction for table: products1 column: cust_id value: c1' in res ['errors' ]
132+ assert 'Missing restriction for table: products1 column: date value: 27-02-2025' in res ['errors' ]
128133
129134 def test_missing_col (self , config ):
130135 res = verify_sql ("SELECT prod_details from products1 where id = 324" , config )
131- assert res ["allowed" ] == False # "errors": ["Column non_existing_column is not allowed. Column not existing"]}
132- print ( res [" errors" ])
136+ assert not res ["allowed" ]
137+ assert "Column prod_details is not allowed. Column removed from SELECT clause" in res [' errors' ]
133138
134139 def test_insert_row_not_allowed (self , config ):
135140 res = verify_sql ("INSERT into products1 values(554, 'prod4', 'shipped', 'granted', '28-02-2025', 'c2')" , config )
@@ -153,6 +158,23 @@ def test_missing_restriction(self, config, cnn):
153158 cursor .execute (result ["fixed" ])
154159 assert cursor .fetchall () == [(324 , "prod1" )]
155160
161+ def test_using_cnn (self , config ,cnn ):
162+ cursor = cnn .cursor ()
163+ sql = "SELECT id, prod_name FROM products1 WHERE id = 324 and access = 'granted' "
164+ cursor .execute (sql )
165+ res = cursor .fetchall ()
166+ expected = [(324 , 'prod1' )]
167+ assert res == expected
168+ res = verify_sql (sql , config )
169+ assert not res ['allowed' ], res
170+ cursor .execute (res ['fixed' ])
171+ assert cursor .fetchall () == [(324 , "prod1" )]
172+
173+ def test_update_value (self ,config ):
174+ res = verify_sql ("Update products1 set id = 224 where id = 324" ,config )
175+ assert res ['allowed' ] == False , res
176+ assert "UPDATE statement is not allowed" in res ['errors' ]
177+
156178class TestJoins :
157179
158180 @pytest .fixture (scope = "class" )
0 commit comments