Skip to content

Commit d5bdf1d

Browse files
committed
update code formatting - use black
1 parent 9004c88 commit d5bdf1d

File tree

5 files changed

+225
-135
lines changed

5 files changed

+225
-135
lines changed

src/sql_data_guard/rest/sql_data_guard_rest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
app = Flask(__name__)
1010

11-
@app.route('/verify-sql', methods=['POST'])
11+
12+
@app.route("/verify-sql", methods=["POST"])
1213
def _verify_sql():
1314
if not request.is_json:
1415
return jsonify({"error": "Request must be JSON"}), 400
@@ -24,15 +25,15 @@ def _verify_sql():
2425
result["errors"] = list(result["errors"])
2526
return jsonify(result)
2627

28+
2729
def _init_logging():
2830
fileConfig(os.path.join(os.path.dirname(os.path.abspath(__file__)), "logging.conf"))
2931
logging.info("Logging initialized")
3032

3133

32-
if __name__ == '__main__':
34+
if __name__ == "__main__":
3335
_init_logging()
34-
logging.getLogger("werkzeug").setLevel('WARNING')
36+
logging.getLogger("werkzeug").setLevel("WARNING")
3537
port = os.environ.get("APP_PORT", 5000)
3638
logging.info(f"Going to start the app. Port: {port}")
3739
app.run(host="0.0.0.0", port=port)
38-

src/sql_data_guard/sql_data_guard.py

Lines changed: 101 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class _VerificationContext:
1818
_dynamic_tables (Set[str]): Set of dynamic tables found in the query, like sub select and WITH clauses.
1919
_dialect (str): The SQL dialect to use for parsing.
2020
"""
21+
2122
def __init__(self, config: dict, dialect: str):
2223
super().__init__()
2324
self._can_fix = True
@@ -50,7 +51,6 @@ def fixed(self) -> Optional[str]:
5051
def fixed(self, value: Optional[str]):
5152
self._fixed = value
5253

53-
5454
@property
5555
def config(self) -> dict:
5656
return self._config
@@ -95,31 +95,48 @@ def verify_sql(sql: str, config: dict, dialect: str = None) -> dict:
9595
if isinstance(parsed, expr.Command):
9696
result.add_error(f"{parsed.name} statement is not allowed", False, 0.9)
9797
elif isinstance(parsed, expr.Delete) or isinstance(parsed, expr.Insert):
98-
result.add_error(f"{parsed.key.upper()} statement is not allowed", False, 0.9)
98+
result.add_error(
99+
f"{parsed.key.upper()} statement is not allowed", False, 0.9
100+
)
99101
elif isinstance(parsed, expr.Query):
100102
_verify_query_statement(parsed, result)
101103
else:
102104
result.add_error("Could not find a query statement", False, 0.7)
103105
if result.can_fix and len(result.errors) > 0:
104106
result.fixed = parsed.sql()
105-
return { "allowed": len(result.errors) == 0, "errors": result.errors, "fixed": result.fixed, "risk": result.risk}
106-
107-
108-
def _verify_where_clause(context: _VerificationContext, select_statement: expr.Query,
109-
from_tables: List[expr.Table]):
107+
return {
108+
"allowed": len(result.errors) == 0,
109+
"errors": result.errors,
110+
"fixed": result.fixed,
111+
"risk": result.risk,
112+
}
113+
114+
115+
def _verify_where_clause(
116+
context: _VerificationContext,
117+
select_statement: expr.Query,
118+
from_tables: List[expr.Table],
119+
):
110120
_verify_static_expression(select_statement, context)
111121
_verify_restrictions(select_statement, context, from_tables)
112122

113-
def _verify_restrictions(select_statement: expr.Query,
114-
context: _VerificationContext,
115-
from_tables: List[expr.Table]):
123+
124+
def _verify_restrictions(
125+
select_statement: expr.Query,
126+
context: _VerificationContext,
127+
from_tables: List[expr.Table],
128+
):
116129
where_clause = select_statement.find(expr.Where)
117130
if where_clause is None:
118131
where_clause = select_statement.find(expr.Where)
119132
and_exps = []
120133
else:
121134
and_exps = list(_split_to_expressions(where_clause.this, expr.And))
122-
for t in [c_t for c_t in context.config["tables"] if c_t["table_name"] in [t.name for t in from_tables]]:
135+
for t in [
136+
c_t
137+
for c_t in context.config["tables"]
138+
if c_t["table_name"] in [t.name for t in from_tables]
139+
]:
123140
for idx, r in enumerate(t.get("restrictions", [])):
124141
found = False
125142
for sub_exp in and_exps:
@@ -129,18 +146,30 @@ def _verify_restrictions(select_statement: expr.Query,
129146
if not found:
130147
context.add_error(
131148
f"Missing restriction for table: {t['table_name']} column: {r['column']} value: {r['value']}",
132-
True, 0.5)
149+
True,
150+
0.5,
151+
)
133152
value = f"'{r['value']}'" if isinstance(r["value"], str) else r["value"]
134-
new_condition = sqlglot.parse_one(f"{r['column']} = {value}", dialect=context.dialect)
153+
new_condition = sqlglot.parse_one(
154+
f"{r['column']} = {value}", dialect=context.dialect
155+
)
135156
if where_clause is None:
136157
where_clause = expr.Where(this=new_condition)
137158
select_statement.set("where", where_clause)
138159
else:
139-
where_clause = where_clause.replace(expr.Where(this=expr.And(this=expr.paren(where_clause.this),
140-
expression=new_condition)))
141-
142-
143-
def _verify_static_expression(select_statement: expr.Query, context: _VerificationContext) -> bool:
160+
where_clause = where_clause.replace(
161+
expr.Where(
162+
this=expr.And(
163+
this=expr.paren(where_clause.this),
164+
expression=new_condition,
165+
)
166+
)
167+
)
168+
169+
170+
def _verify_static_expression(
171+
select_statement: expr.Query, context: _VerificationContext
172+
) -> bool:
144173
has_static_exp = False
145174
where_clause = select_statement.find(expr.Where)
146175
if where_clause:
@@ -152,6 +181,7 @@ def _verify_static_expression(select_statement: expr.Query, context: _Verificati
152181
simplify(where_clause)
153182
return not has_static_exp
154183

184+
155185
def _has_static_expression(context: _VerificationContext, exp: expr.Expression) -> bool:
156186
if isinstance(exp, expr.Not):
157187
return _has_static_expression(context, exp.this)
@@ -162,7 +192,8 @@ def _has_static_expression(context: _VerificationContext, exp: expr.Expression)
162192
result = _has_static_expression(context, sub_exp)
163193
elif not sub_exp.find(expr.Column):
164194
context.add_error(
165-
f"Static expression is not allowed: {sub_exp.sql()}", True, 0.8)
195+
f"Static expression is not allowed: {sub_exp.sql()}", True, 0.8
196+
)
166197
par = sub_exp.parent
167198
while isinstance(par, expr.Paren):
168199
par = par.parent
@@ -176,15 +207,15 @@ def _has_static_expression(context: _VerificationContext, exp: expr.Expression)
176207

177208
def _verify_restriction(restriction: dict, exp: expr.Expression) -> bool:
178209
"""
179-
Verifies if a given restriction is satisfied within an SQL expression.
210+
Verifies if a given restriction is satisfied within an SQL expression.
180211
181-
Args:
182-
restriction (dict): The restriction to verify, containing 'column' and 'value' keys.
183-
exp (list): The SQL expression to check against the restriction.
212+
Args:
213+
restriction (dict): The restriction to verify, containing 'column' and 'value' keys.
214+
exp (list): The SQL expression to check against the restriction.
184215
185-
Returns:
186-
bool: True if the restriction is satisfied, False otherwise.
187-
"""
216+
Returns:
217+
bool: True if the restriction is satisfied, False otherwise.
218+
"""
188219
if isinstance(exp, expr.Not):
189220
return False
190221
if isinstance(exp, expr.Paren):
@@ -202,8 +233,8 @@ def _verify_restriction(restriction: dict, exp: expr.Expression) -> bool:
202233
return exp.right.this == str(restriction["value"])
203234
return False
204235

205-
def _verify_query_statement(query_statement: expr.Query,
206-
context: _VerificationContext):
236+
237+
def _verify_query_statement(query_statement: expr.Query, context: _VerificationContext):
207238
if isinstance(query_statement, expr.Union):
208239
_verify_query_statement(query_statement.left, context)
209240
_verify_query_statement(query_statement.right, context)
@@ -226,9 +257,11 @@ def _verify_query_statement(query_statement: expr.Query,
226257
return query_statement
227258

228259

229-
def _verify_select_clause(context: _VerificationContext,
230-
select_clause: expr.Query,
231-
from_tables: List[expr.Table]):
260+
def _verify_select_clause(
261+
context: _VerificationContext,
262+
select_clause: expr.Query,
263+
from_tables: List[expr.Table],
264+
):
232265
to_remove = []
233266
for e in select_clause.expressions:
234267
if not _verify_select_clause_element(from_tables, context, e):
@@ -238,8 +271,10 @@ def _verify_select_clause(context: _VerificationContext,
238271
if len(select_clause.expressions) == 0:
239272
context.add_error("No legal elements in SELECT clause", False, 0.5)
240273

241-
def _verify_select_clause_element(from_tables: List[expr.Table], context: _VerificationContext,
242-
e: expr.Expression):
274+
275+
def _verify_select_clause_element(
276+
from_tables: List[expr.Table], context: _VerificationContext, e: expr.Expression
277+
):
243278
if isinstance(e, expr.Column):
244279
if not _verify_col(e, from_tables, context):
245280
return False
@@ -249,7 +284,9 @@ def _verify_select_clause_element(from_tables: List[expr.Table], context: _Verif
249284
for config_t in context.config["tables"]:
250285
if t.name == config_t["table_name"]:
251286
for c in config_t["columns"]:
252-
e.parent.set("expressions", e.parent.expressions + [sqlglot.parse_one(c)])
287+
e.parent.set(
288+
"expressions", e.parent.expressions + [sqlglot.parse_one(c)]
289+
)
253290
return False
254291
elif isinstance(e, expr.Tuple):
255292
result = True
@@ -263,7 +300,10 @@ def _verify_select_clause_element(from_tables: List[expr.Table], context: _Verif
263300
return False
264301
return True
265302

266-
def _verify_col(col: expr.Column, from_tables: List[expr.Table], context: _VerificationContext) -> bool:
303+
304+
def _verify_col(
305+
col: expr.Column, from_tables: List[expr.Table], context: _VerificationContext
306+
) -> bool:
267307
"""
268308
Verifies if a column reference is allowed based on the provided tables and context.
269309
@@ -275,16 +315,25 @@ def _verify_col(col: expr.Column, from_tables: List[expr.Table], context: _Verif
275315
Returns:
276316
bool: True if the column reference is allowed, False otherwise.
277317
"""
278-
if col.table == "sub_select" or col.table != "" and col.table in context.dynamic_tables:
318+
if (
319+
col.table == "sub_select"
320+
or col.table != ""
321+
and col.table in context.dynamic_tables
322+
):
279323
pass
280324
elif not _find_column(col.name, from_tables, context):
281-
context.add_error(f"Column {col.name} is not allowed. Column removed from SELECT clause",
282-
True,0.3)
325+
context.add_error(
326+
f"Column {col.name} is not allowed. Column removed from SELECT clause",
327+
True,
328+
0.3,
329+
)
283330
return False
284331
return True
285332

286333

287-
def _find_column(col_name: str, from_tables: List[expr.Table], result: _VerificationContext) -> bool:
334+
def _find_column(
335+
col_name: str, from_tables: List[expr.Table], result: _VerificationContext
336+
) -> bool:
288337
"""
289338
Finds a column in the given tables based on the provided column name.
290339
@@ -306,16 +355,18 @@ def _find_column(col_name: str, from_tables: List[expr.Table], result: _Verifica
306355
return False
307356

308357

309-
def _get_from_clause_tables(select_clause: expr.Query, context: _VerificationContext) -> List[expr.Table]:
358+
def _get_from_clause_tables(
359+
select_clause: expr.Query, context: _VerificationContext
360+
) -> List[expr.Table]:
310361
"""
311-
Extracts table references from the FROM clause of an SQL query.
362+
Extracts table references from the FROM clause of an SQL query.
312363
313-
Args:
314-
select_clause (dict): The FROM clause of the SQL query.
315-
context (_VerificationContext): The context for verification.
364+
Args:
365+
select_clause (dict): The FROM clause of the SQL query.
366+
context (_VerificationContext): The context for verification.
316367
317-
Returns:
318-
List[_TableRef]: A list of table references to find in the FROM clause.
368+
Returns:
369+
List[_TableRef]: A list of table references to find in the FROM clause.
319370
"""
320371
result = []
321372
from_clause = select_clause.find(expr.From)
@@ -337,13 +388,15 @@ def _get_from_clause_tables(select_clause: expr.Query, context: _VerificationCon
337388
return result
338389

339390

340-
def _split_to_expressions(exp: expr.Expression,
341-
exp_type: Type[expr.Expression]) -> Generator[expr.Expression, None, None]:
391+
def _split_to_expressions(
392+
exp: expr.Expression, exp_type: Type[expr.Expression]
393+
) -> Generator[expr.Expression, None, None]:
342394
if isinstance(exp, exp_type):
343395
yield from exp.flatten()
344396
else:
345397
yield exp
346398

399+
347400
def _find_direct(exp: expr.Expression, exp_type: Type[expr.Expression]):
348401
for child in exp.args.values():
349402
if isinstance(child, exp_type):

test/test_rest_api_unit.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from sql_data_guard.rest import app
44

5+
56
class TestRestAppErrors:
67
def test_verify_sql_method_not_allowed(self):
78
result = app.test_client().get("/verify-sql")
@@ -18,39 +19,49 @@ def test_verify_sql_no_sql(self):
1819
assert result.json == {"error": "Missing 'sql' in request"}
1920

2021
def test_very_sql_no_config(self):
21-
result = app.test_client().post("/verify-sql", json={"sql": "SELECT * FROM my_table"})
22+
result = app.test_client().post(
23+
"/verify-sql", json={"sql": "SELECT * FROM my_table"}
24+
)
2225
assert result.status_code == 400
2326
assert result.json == {"error": "Missing 'config' in request"}
2427

28+
2529
class TestRestAppVerifySql:
2630
@pytest.fixture(scope="class")
2731
def config(self) -> dict:
28-
return { "tables": [
29-
{
30-
"table_name": "orders",
31-
"database_name": "orders_db",
32-
"columns": ["id", "product_name", "account_id", "day"],
33-
"restrictions": [{"column": "id", "value": 123}]
34-
}
35-
]
36-
}
32+
return {
33+
"tables": [
34+
{
35+
"table_name": "orders",
36+
"database_name": "orders_db",
37+
"columns": ["id", "product_name", "account_id", "day"],
38+
"restrictions": [{"column": "id", "value": 123}],
39+
}
40+
]
41+
}
3742

3843
def test_verify_sql(self, config):
39-
result = app.test_client().post("/verify-sql",
40-
json={"sql": "SELECT id FROM orders WHERE id = 123",
41-
"config": config})
44+
result = app.test_client().post(
45+
"/verify-sql",
46+
json={"sql": "SELECT id FROM orders WHERE id = 123", "config": config},
47+
)
4248
assert result.status_code == 200
43-
assert result.json == {'allowed': True, 'errors': [], 'fixed': None, "risk": 0}
44-
49+
assert result.json == {"allowed": True, "errors": [], "fixed": None, "risk": 0}
4550

4651
def test_verify_sql_error(self, config):
47-
result = app.test_client().post("/verify-sql",
48-
json={"sql": "SELECT id, another_col FROM orders WHERE id = 123",
49-
"config": config})
52+
result = app.test_client().post(
53+
"/verify-sql",
54+
json={
55+
"sql": "SELECT id, another_col FROM orders WHERE id = 123",
56+
"config": config,
57+
},
58+
)
5059
assert result.status_code == 200
5160
assert result.json == {
52-
'allowed': False,
53-
'errors': ['Column another_col is not allowed. Column removed from SELECT clause'],
54-
'fixed': 'SELECT id FROM orders WHERE id = 123',
55-
'risk': 0.3
56-
}
61+
"allowed": False,
62+
"errors": [
63+
"Column another_col is not allowed. Column removed from SELECT clause"
64+
],
65+
"fixed": "SELECT id FROM orders WHERE id = 123",
66+
"risk": 0.3,
67+
}

0 commit comments

Comments
 (0)