Skip to content

Commit 66ce7e0

Browse files
committed
Merge branch 'dev' of https://github.com/ThalesGroup/sql-data-guard into dev
2 parents b8ec835 + d4b6d58 commit 66ce7e0

File tree

2 files changed

+104
-52
lines changed

2 files changed

+104
-52
lines changed

src/sql_data_guard/restriction_verification.py

Lines changed: 31 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -103,65 +103,47 @@ def _verify_restriction(
103103
Verifies if a given restriction is satisfied within an SQL expression.
104104
105105
Args:
106-
restriction (dict): The restriction to verify, containing 'column' and 'value' keys.
106+
restriction (dict): The restriction to verify, containing 'column' and 'value' or 'values'.
107107
from_table (Table): The table reference to check the restriction against.
108-
exp (list): The SQL expression to check against the restriction.
108+
exp (Expression): The SQL expression to check against the restriction.
109109
110110
Returns:
111111
bool: True if the restriction is satisfied, False otherwise.
112112
"""
113+
113114
if isinstance(exp, expr.Not):
114115
return False
115116

116117
if isinstance(exp, expr.Paren):
117118
return _verify_restriction(restriction, from_table, exp.this)
118-
if not isinstance(exp.this, expr.Column):
119-
return False
120-
if not exp.this.name == restriction["column"]:
119+
120+
if not isinstance(exp.this, expr.Column) or exp.this.name != restriction["column"]:
121121
return False
122+
122123
if exp.this.table and from_table.alias and exp.this.table != from_table.alias:
123124
return False
124125
if exp.this.table and not from_table.alias and exp.this.table != from_table.name:
125126
return False
126-
if isinstance(exp, expr.EQ) and isinstance(exp.right, expr.Condition):
127-
if isinstance(exp.right, expr.Boolean):
128-
return exp.right.this == restriction["value"]
129-
else:
130-
values = _get_restriction_values(restriction)
131-
return exp.right.this in values
132127

133-
# Check if the expression is a BETWEEN condition
134-
if isinstance(exp, expr.Between):
135-
low = int(exp.args["low"].this) # Extract the lower bound
136-
high = int(exp.args["high"].this) # Extract the upper bound
137-
restriction_low, restriction_high = map(
138-
int, restriction["values"]
139-
) # Get allowed range from restriction
140-
# Return True only if the given range is within the allowed range
141-
return restriction_low <= low and high <= restriction_high
142-
143-
# Check if the expression is a NOT BETWEEN condition (e.g., price NOT BETWEEN 80 AND 150)
144-
if isinstance(exp, expr.Not) and isinstance(exp.this, expr.Between):
145-
low = int(exp.this.args["low"].this) # Extract lower bound
146-
high = int(exp.this.args["high"].this) # Extract upper bound
147-
restriction_low, restriction_high = map(
148-
int, restriction["values"]
149-
) # Convert to int
150-
# NOT BETWEEN should be valid if the range is completely outside the restriction
151-
# Ensures it's fully outside
152-
return (
153-
low < restriction_low or high > restriction_high
154-
) # Ensures it's fully outside the allowed range
155-
156-
# Check if the expression is an IN condition (e.g., price IN (100, 120, 150))
128+
values = _get_restriction_values(restriction) # Get correct restriction values
129+
130+
# Handle IN condition correctly
157131
if isinstance(exp, expr.In):
158-
expr_values = [int(val.this) for val in exp.expressions] # Extract SQL values
159-
restriction_values = [
160-
int(val) for val in restriction["values"]
161-
] # Extract allowed values
132+
expr_values = [str(val.this) for val in exp.expressions]
133+
return any(v in values for v in expr_values)
134+
135+
# Handle EQ (=) condition
136+
if isinstance(exp, expr.EQ) and isinstance(exp.right, expr.Condition):
137+
return str(exp.right.this) in values
162138

163-
return any(v in restriction_values for v in expr_values)
139+
# Handle BETWEEN conditions correctly
140+
if isinstance(exp, expr.Between):
141+
low, high = int(exp.args["low"].this), int(exp.args["high"].this)
142+
if len(values) == 2: # Ensure we have exactly two values
143+
restriction_low, restriction_high = map(int, values)
144+
return restriction_low <= low and high <= restriction_high
164145

146+
# Handle comparison operators (<, >, <=, >=)
165147
def check_comparison_operator(exp1, restriction_, operator):
166148
"""Handles LT (<), GT (>), LTE (<=), and GTE (>=) conditions."""
167149
if not isinstance(exp1, operator):
@@ -179,33 +161,30 @@ def check_comparison_operator(exp1, restriction_, operator):
179161

180162
value = int(exp1.expression.this) # Extract the number after the operator
181163

182-
if "values" in restriction_: # If a range is given (e.g., [80, 150])
183-
low_restriction, high_restriction = map(int, restriction_["values"])
164+
if len(values) == 2: # If a range is given (e.g., [80, 150])
165+
low_restriction, high_restriction = map(int, values)
184166
if operator in [expr.GT, expr.GTE]:
185167
return low_restriction <= value <= high_restriction
186168
return low_restriction <= value # For LT, LTE
187169

188-
else: # If only a single value exists
189-
restriction_value = int(restriction_["value"])
170+
elif len(values) == 1: # If only a single value exists
171+
restriction_value = int(values[0])
190172
return {
191173
expr.GT: value > restriction_value,
192174
expr.GTE: value >= restriction_value,
193175
expr.LT: value < restriction_value,
194176
expr.LTE: value <= restriction_value,
195-
}[
196-
operator
197-
] # Direct lookup, avoids unnecessary `.get(operator, False)`
177+
}[operator]
178+
179+
return False # Default case
198180

199-
# Apply the function to different comparison operators
200181
if any(
201182
check_comparison_operator(exp, restriction, op)
202183
for op in [expr.LT, expr.GT, expr.LTE, expr.GTE]
203184
):
204-
result = True # Assign instead of `return` inside a loop
205-
else:
206-
result = False # Assign explicitly
185+
return True
207186

208-
return result # Single return statement outside the loop
187+
return False
209188

210189

211190
def _get_restriction_values(restriction: dict) -> List[str]:

test/test_sql_guard_curr_unit.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,3 +628,76 @@ def test_group_by_having_price(self, cnn, config):
628628
cnn=cnn,
629629
data=[("CategoryA", 120)], # Products in CategoryA with price > 100
630630
)
631+
632+
633+
class TestSQLOrderRestrictions:
634+
635+
@pytest.fixture(scope="class")
636+
def cnn(self):
637+
with sqlite3.connect(":memory:") as conn:
638+
# Create orders table
639+
conn.execute(
640+
"""
641+
CREATE TABLE orders (
642+
id INTEGER,
643+
product_name TEXT,
644+
account_id INTEGER
645+
)"""
646+
)
647+
648+
# Insert sample data into orders table
649+
650+
conn.execute(
651+
"""INSERT INTO orders (id, product_name, account_id)
652+
VALUES
653+
(1, 'Product A', 123),
654+
(2, 'Product B', 124),
655+
(3, "Product C", 125)
656+
"""
657+
)
658+
659+
yield conn
660+
661+
@pytest.fixture(scope="class")
662+
def config(self):
663+
# Assuming self._ALLOWED_ACCOUNT_ID is defined
664+
self._ALLOWED_ACCOUNT_ID = 124 # Example value for the allowed account ID
665+
self._TABLE_NAME = "orders" # Define table name
666+
667+
return {
668+
"tables": [
669+
{
670+
"table_name": self._TABLE_NAME,
671+
"columns": ["id", "product_name", "account_id"],
672+
"restrictions": [
673+
{
674+
"column": "account_id",
675+
"value": [
676+
self._ALLOWED_ACCOUNT_ID,
677+
],
678+
} # Restriction without IN
679+
],
680+
}
681+
]
682+
}
683+
684+
def test_in_operator_with_restriction_(self, config, cnn):
685+
sql = """SELECT product_name FROM orders WHERE account_id IN (123, 124, 125)"""
686+
687+
# Modify the config to handle "value" as "values" just for this specific test case
688+
for table in config["tables"]:
689+
for restriction in table["restrictions"]:
690+
if "value" in restriction:
691+
# If 'value' is present, convert it to 'values'
692+
restriction["values"] = restriction["value"]
693+
del restriction["value"] # Remove 'value' key
694+
695+
# Run the verify_sql_test function with the defined SQL query and configuration
696+
verify_sql_test(
697+
sql,
698+
config,
699+
errors=set(), # No errors expected as the restriction matches the IN clause
700+
fix=None, # No fix should be needed
701+
cnn=cnn,
702+
data=[("Product A",), ("Product B",), ("Product C",)],
703+
)

0 commit comments

Comments
 (0)