Skip to content

Commit 8babb68

Browse files
Added update value test case and added the error in sql_data_guard.py
1 parent 4b10cf1 commit 8babb68

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/sql_data_guard/sql_data_guard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
9494
if parsed:
9595
if isinstance(parsed, expr.Command):
9696
result.add_error(f"{parsed.name} statement is not allowed", False, 0.9)
97-
elif isinstance(parsed, expr.Delete) or isinstance(parsed, expr.Insert):
97+
elif isinstance(parsed, expr.Delete) or isinstance(parsed, expr.Insert) or isinstance(parsed, expr.Update):
9898
result.add_error(
9999
f"{parsed.key.upper()} statement is not allowed", False, 0.9
100100
)

test/test_sql_guard_temp_unit.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,28 @@
99

1010
from sql_data_guard import verify_sql
1111

12+
def _test_sql(sql: str, config: dict, errors: Set[str] = None, fix: str = None, dialect: str = "sqlite",
13+
cnn: Connection = None, data: list = None):
14+
result = verify_sql(sql, config, dialect)
15+
if errors is None:
16+
assert result["errors"] == set()
17+
else:
18+
assert set(result["errors"]) == set(errors)
19+
if len(result["errors"]) > 0:
20+
assert result["risk"] > 0
21+
else:
22+
assert result["risk"] == 0
23+
if fix is None:
24+
assert result.get("fixed") is None
25+
sql_to_use = sql
26+
else:
27+
assert result["fixed"] == fix
28+
sql_to_use = result["fixed"]
29+
if cnn and data:
30+
fetched_data = cnn.execute(sql_to_use).fetchall()
31+
if data is not None:
32+
assert fetched_data == [tuple(row) for row in data]
33+
1234
class TestInvalidQueries:
1335

1436
@pytest.fixture(scope="class")
@@ -136,6 +158,23 @@ def test_missing_restriction(self, config, cnn):
136158
cursor.execute(result["fixed"])
137159
assert cursor.fetchall() == [(324, "prod1")]
138160

161+
def test_using_cnn(self, config,cnn):
162+
cursor = cnn.cursor()
163+
sql = "SELECT id, prod_name FROM products1 WHERE id = 324 and access = 'granted' "
164+
cursor.execute(sql)
165+
res = cursor.fetchall()
166+
expected = [(324, 'prod1')]
167+
assert res == expected
168+
res = verify_sql(sql, config)
169+
assert not res['allowed'], res
170+
cursor.execute(res['fixed'])
171+
assert cursor.fetchall() == [(324, "prod1")]
172+
173+
def test_update_value(self,config):
174+
res = verify_sql("Update products1 set id = 224 where id = 324",config)
175+
assert res['allowed'] == False, res
176+
assert "UPDATE statement is not allowed" in res['errors']
177+
139178
class TestJoins:
140179

141180
@pytest.fixture(scope="class")

0 commit comments

Comments
 (0)