|
9 | 9 |
|
10 | 10 | from sql_data_guard import verify_sql |
11 | 11 |
|
| 12 | +def _test_sql(sql: str, config: dict, errors: Set[str] = None, fix: str = None, dialect: str = "sqlite", |
| 13 | + cnn: Connection = None, data: list = None): |
| 14 | + result = verify_sql(sql, config, dialect) |
| 15 | + if errors is None: |
| 16 | + assert result["errors"] == set() |
| 17 | + else: |
| 18 | + assert set(result["errors"]) == set(errors) |
| 19 | + if len(result["errors"]) > 0: |
| 20 | + assert result["risk"] > 0 |
| 21 | + else: |
| 22 | + assert result["risk"] == 0 |
| 23 | + if fix is None: |
| 24 | + assert result.get("fixed") is None |
| 25 | + sql_to_use = sql |
| 26 | + else: |
| 27 | + assert result["fixed"] == fix |
| 28 | + sql_to_use = result["fixed"] |
| 29 | + if cnn and data: |
| 30 | + fetched_data = cnn.execute(sql_to_use).fetchall() |
| 31 | + if data is not None: |
| 32 | + assert fetched_data == [tuple(row) for row in data] |
| 33 | + |
12 | 34 | class TestInvalidQueries: |
13 | 35 |
|
14 | 36 | @pytest.fixture(scope="class") |
@@ -136,6 +158,23 @@ def test_missing_restriction(self, config, cnn): |
136 | 158 | cursor.execute(result["fixed"]) |
137 | 159 | assert cursor.fetchall() == [(324, "prod1")] |
138 | 160 |
|
| 161 | + def test_using_cnn(self, config,cnn): |
| 162 | + cursor = cnn.cursor() |
| 163 | + sql = "SELECT id, prod_name FROM products1 WHERE id = 324 and access = 'granted' " |
| 164 | + cursor.execute(sql) |
| 165 | + res = cursor.fetchall() |
| 166 | + expected = [(324, 'prod1')] |
| 167 | + assert res == expected |
| 168 | + res = verify_sql(sql, config) |
| 169 | + assert not res['allowed'], res |
| 170 | + cursor.execute(res['fixed']) |
| 171 | + assert cursor.fetchall() == [(324, "prod1")] |
| 172 | + |
| 173 | + def test_update_value(self,config): |
| 174 | + res = verify_sql("Update products1 set id = 224 where id = 324",config) |
| 175 | + assert res['allowed'] == False, res |
| 176 | + assert "UPDATE statement is not allowed" in res['errors'] |
| 177 | + |
139 | 178 | class TestJoins: |
140 | 179 |
|
141 | 180 | @pytest.fixture(scope="class") |
|
0 commit comments