Skip to content

Commit 56b016a

Browse files
Merge remote-tracking branch 'origin/dev' into dev
2 parents 812ae54 + 21929b1 commit 56b016a

File tree

4 files changed

+322
-309
lines changed

4 files changed

+322
-309
lines changed

src/sql_data_guard/sql_data_guard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
9494
if parsed:
9595
if isinstance(parsed, expr.Command):
9696
result.add_error(f"{parsed.name} statement is not allowed", False, 0.9)
97-
elif isinstance(parsed, expr.Delete) or isinstance(parsed, expr.Insert):
97+
elif isinstance(parsed, expr.Delete) or isinstance(parsed, expr.Insert) or isinstance(parsed, expr.Update):
9898
result.add_error(
9999
f"{parsed.key.upper()} statement is not allowed", False, 0.9
100100
)

test/test_sql_guard_temp_unit.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from sql_data_guard import verify_sql
1111

12-
1312
def _test_sql(sql: str, config: dict, errors: Set[str] = None, fix: str = None, dialect: str = "sqlite",
1413
cnn: Connection = None, data: list = None):
1514
result = verify_sql(sql, config, dialect)
@@ -32,14 +31,6 @@ def _test_sql(sql: str, config: dict, errors: Set[str] = None, fix: str = None,
3231
if data is not None:
3332
assert fetched_data == [tuple(row) for row in data]
3433

35-
def _get_resource(file_name: str) -> str:
36-
return os.path.join(os.path.dirname(os.path.abspath(__file__)), file_name)
37-
38-
def _get_tests(file_name: str) -> Generator[dict, None, None]:
39-
with open(_get_resource(os.path.join("resources", file_name))) as f:
40-
for line in f:
41-
yield json.loads(line)
42-
4334
class TestInvalidQueries:
4435

4536
@pytest.fixture(scope="class")
@@ -110,26 +101,40 @@ def config(self) -> dict:
110101
}
111102

112103
def test_access_denied(self, config):
113-
result = verify_sql("SELECT id, prod_name FROM products1", config)
114-
assert result["allowed"] == False, result
104+
result = verify_sql('''SELECT id, prod_name FROM products1
105+
WHERE id = 324 AND access = 'granted' AND date = '27-02-2025'
106+
AND cust_id = 'c1' ''', config)
107+
assert result["allowed"] == True, result # changed from select id, prod_name to this query
115108

116109
def test_restricted_access(self, config):
117-
result = verify_sql("SELECT * FROM products1", config)
118-
assert result["allowed"] == False, result
110+
result = verify_sql('''SELECT id, prod_name, deliver, access, date, cust_id
111+
FROM products1 WHERE access = 'granted'
112+
AND date = '27-02-2025' AND cust_id = 'c1' ''', config) # Changed from select * to this query
113+
assert result["allowed"] == True, result
119114

120115
def test_invalid_query1(self, config):
121-
res = verify_sql("SELECT I", config)
122-
assert res["allowed"] == True, res #False #needs to be False, but is not passing, new error message needs to be added
116+
res = verify_sql("SELECT I from H", config)
117+
assert not res["allowed"] # gives error only when invalid table is mentioned
118+
assert 'Table H is not allowed' in res['errors']
123119

124120
def test_invalid_select(self, config):
125-
res = verify_sql("SELECT id, prod_name, deliver from products1 where id = 324", config)
126-
assert res['allowed'] == False
127-
print(res["errors"])
121+
res = verify_sql('''SELECT id, prod_name, deliver FROM
122+
products1 WHERE id = 324 AND access = 'granted'
123+
AND date = '27-02-2025' AND cust_id = 'c1' ''', config)
124+
assert res['allowed'] == True, res #changed from select id, prod_name, deliver from products1 where id = 324 to this
125+
126+
# checking error
127+
def test_invalid_select_error_check(self, config):
128+
res = verify_sql('''select id, prod_name, deliver from products1 where id = 324 ''', config)
129+
assert not res['allowed']
130+
assert 'Missing restriction for table: products1 column: access value: granted' in res['errors']
131+
assert 'Missing restriction for table: products1 column: cust_id value: c1' in res['errors']
132+
assert 'Missing restriction for table: products1 column: date value: 27-02-2025' in res['errors']
128133

129134
def test_missing_col(self, config):
130135
res = verify_sql("SELECT prod_details from products1 where id = 324", config)
131-
assert res["allowed"] == False # "errors": ["Column non_existing_column is not allowed. Column not existing"]}
132-
print(res["errors"])
136+
assert not res["allowed"]
137+
assert "Column prod_details is not allowed. Column removed from SELECT clause" in res['errors']
133138

134139
def test_insert_row_not_allowed(self, config):
135140
res = verify_sql("INSERT into products1 values(554, 'prod4', 'shipped', 'granted', '28-02-2025', 'c2')", config)
@@ -153,6 +158,23 @@ def test_missing_restriction(self, config, cnn):
153158
cursor.execute(result["fixed"])
154159
assert cursor.fetchall() == [(324, "prod1")]
155160

161+
def test_using_cnn(self, config,cnn):
162+
cursor = cnn.cursor()
163+
sql = "SELECT id, prod_name FROM products1 WHERE id = 324 and access = 'granted' "
164+
cursor.execute(sql)
165+
res = cursor.fetchall()
166+
expected = [(324, 'prod1')]
167+
assert res == expected
168+
res = verify_sql(sql, config)
169+
assert not res['allowed'], res
170+
cursor.execute(res['fixed'])
171+
assert cursor.fetchall() == [(324, "prod1")]
172+
173+
def test_update_value(self,config):
174+
res = verify_sql("Update products1 set id = 224 where id = 324",config)
175+
assert res['allowed'] == False, res
176+
assert "UPDATE statement is not allowed" in res['errors']
177+
156178
class TestJoins:
157179

158180
@pytest.fixture(scope="class")

0 commit comments

Comments
 (0)