@@ -103,65 +103,47 @@ def _verify_restriction(
103103 Verifies if a given restriction is satisfied within an SQL expression.
104104
105105 Args:
106- restriction (dict): The restriction to verify, containing 'column' and 'value' keys .
106+ restriction (dict): The restriction to verify, containing 'column' and 'value' or 'values' .
107107 from_table (Table): The table reference to check the restriction against.
108- exp (list ): The SQL expression to check against the restriction.
108+ exp (Expression ): The SQL expression to check against the restriction.
109109
110110 Returns:
111111 bool: True if the restriction is satisfied, False otherwise.
112112 """
113+
113114 if isinstance (exp , expr .Not ):
114115 return False
115116
116117 if isinstance (exp , expr .Paren ):
117118 return _verify_restriction (restriction , from_table , exp .this )
118- if not isinstance (exp .this , expr .Column ):
119- return False
120- if not exp .this .name == restriction ["column" ]:
119+
120+ if not isinstance (exp .this , expr .Column ) or exp .this .name != restriction ["column" ]:
121121 return False
122+
122123 if exp .this .table and from_table .alias and exp .this .table != from_table .alias :
123124 return False
124125 if exp .this .table and not from_table .alias and exp .this .table != from_table .name :
125126 return False
126- if isinstance (exp , expr .EQ ) and isinstance (exp .right , expr .Condition ):
127- if isinstance (exp .right , expr .Boolean ):
128- return exp .right .this == restriction ["value" ]
129- else :
130- values = _get_restriction_values (restriction )
131- return exp .right .this in values
132127
133- # Check if the expression is a BETWEEN condition
134- if isinstance (exp , expr .Between ):
135- low = int (exp .args ["low" ].this ) # Extract the lower bound
136- high = int (exp .args ["high" ].this ) # Extract the upper bound
137- restriction_low , restriction_high = map (
138- int , restriction ["values" ]
139- ) # Get allowed range from restriction
140- # Return True only if the given range is within the allowed range
141- return restriction_low <= low and high <= restriction_high
142-
143- # Check if the expression is a NOT BETWEEN condition (e.g., price NOT BETWEEN 80 AND 150)
144- if isinstance (exp , expr .Not ) and isinstance (exp .this , expr .Between ):
145- low = int (exp .this .args ["low" ].this ) # Extract lower bound
146- high = int (exp .this .args ["high" ].this ) # Extract upper bound
147- restriction_low , restriction_high = map (
148- int , restriction ["values" ]
149- ) # Convert to int
150- # NOT BETWEEN should be valid if the range is completely outside the restriction
151- # Ensures it's fully outside
152- return (
153- low < restriction_low or high > restriction_high
154- ) # Ensures it's fully outside the allowed range
155-
156- # Check if the expression is an IN condition (e.g., price IN (100, 120, 150))
128+ values = _get_restriction_values (restriction ) # Get correct restriction values
129+
130+ # Handle IN condition correctly
157131 if isinstance (exp , expr .In ):
158- expr_values = [int (val .this ) for val in exp .expressions ] # Extract SQL values
159- restriction_values = [
160- int (val ) for val in restriction ["values" ]
161- ] # Extract allowed values
132+ expr_values = [str (val .this ) for val in exp .expressions ]
133+ return any (v in values for v in expr_values )
134+
135+ # Handle EQ (=) condition
136+ if isinstance (exp , expr .EQ ) and isinstance (exp .right , expr .Condition ):
137+ return str (exp .right .this ) in values
162138
163- return any (v in restriction_values for v in expr_values )
139+ # Handle BETWEEN conditions correctly
140+ if isinstance (exp , expr .Between ):
141+ low , high = int (exp .args ["low" ].this ), int (exp .args ["high" ].this )
142+ if len (values ) == 2 : # Ensure we have exactly two values
143+ restriction_low , restriction_high = map (int , values )
144+ return restriction_low <= low and high <= restriction_high
164145
146+ # Handle comparison operators (<, >, <=, >=)
165147 def check_comparison_operator (exp1 , restriction_ , operator ):
166148 """Handles LT (<), GT (>), LTE (<=), and GTE (>=) conditions."""
167149 if not isinstance (exp1 , operator ):
@@ -179,33 +161,30 @@ def check_comparison_operator(exp1, restriction_, operator):
179161
180162 value = int (exp1 .expression .this ) # Extract the number after the operator
181163
182- if " values" in restriction_ : # If a range is given (e.g., [80, 150])
183- low_restriction , high_restriction = map (int , restriction_ [ " values" ] )
164+ if len ( values ) == 2 : # If a range is given (e.g., [80, 150])
165+ low_restriction , high_restriction = map (int , values )
184166 if operator in [expr .GT , expr .GTE ]:
185167 return low_restriction <= value <= high_restriction
186168 return low_restriction <= value # For LT, LTE
187169
188- else : # If only a single value exists
189- restriction_value = int (restriction_ [ "value" ])
170+ elif len ( values ) == 1 : # If only a single value exists
171+ restriction_value = int (values [ 0 ])
190172 return {
191173 expr .GT : value > restriction_value ,
192174 expr .GTE : value >= restriction_value ,
193175 expr .LT : value < restriction_value ,
194176 expr .LTE : value <= restriction_value ,
195- }[
196- operator
197- ] # Direct lookup, avoids unnecessary `.get(operator, False)`
177+ }[operator ]
178+
179+ return False # Default case
198180
199- # Apply the function to different comparison operators
200181 if any (
201182 check_comparison_operator (exp , restriction , op )
202183 for op in [expr .LT , expr .GT , expr .LTE , expr .GTE ]
203184 ):
204- result = True # Assign instead of `return` inside a loop
205- else :
206- result = False # Assign explicitly
185+ return True
207186
208- return result # Single return statement outside the loop
187+ return False
209188
210189
211190def _get_restriction_values (restriction : dict ) -> List [str ]:
0 commit comments