Skip to content

Commit 6d28ceb

Browse files
committed
add extra test cases to attack_wave_detector
1 parent 9c83555 commit 6d28ceb

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

aikido_zen/vulnerabilities/attack_wave_detection/attack_wave_detector_test.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,224 @@ def test_a_web_scanner_with_delays():
123123
assert detector.is_attack_wave(context)
124124

125125

126+
def test_unique_samples_only():
127+
"""Test that only unique samples are stored (non-unique contexts aren't stored)"""
128+
detector = new_attack_wave_detector()
129+
130+
# Create multiple contexts with the same method and URL
131+
context1 = test_utils.generate_context(method="GET")
132+
context2 = test_utils.generate_context(method="GET") # Same as context1
133+
context3 = test_utils.generate_context(method="GET") # Same as context1
134+
135+
with patch(
136+
"aikido_zen.vulnerabilities.attack_wave_detection.attack_wave_detector.is_web_scanner",
137+
return_value=True,
138+
):
139+
# Make enough requests to trigger attack wave (threshold is 6)
140+
for i in range(6):
141+
detector.is_attack_wave(context1)
142+
143+
# Make a few more identical requests with different context objects
144+
for i in range(3):
145+
detector.is_attack_wave(context2)
146+
detector.is_attack_wave(context3)
147+
148+
# Should have only 1 unique sample despite 9 identical requests
149+
samples = detector.get_samples_for_ip(context1.remote_address)
150+
assert len(samples) == 1
151+
assert samples[0]["method"] == "GET"
152+
assert samples[0]["url"] == context1.url
153+
154+
155+
def test_unique_samples_with_different_methods():
156+
"""Test that different methods for the same URL are stored as separate samples"""
157+
detector = new_attack_wave_detector()
158+
159+
# Create contexts with different methods (URL will be the same due to test_utils limitation)
160+
context_get = test_utils.generate_context(method="GET")
161+
context_post = test_utils.generate_context(method="POST")
162+
context_put = test_utils.generate_context(method="PUT")
163+
164+
with patch(
165+
"aikido_zen.vulnerabilities.attack_wave_detection.attack_wave_detector.is_web_scanner",
166+
return_value=True,
167+
):
168+
# Make enough requests to trigger attack wave for each method
169+
for i in range(2): # 2 requests per method = 6 total
170+
detector.is_attack_wave(context_get)
171+
detector.is_attack_wave(context_post)
172+
detector.is_attack_wave(context_put)
173+
174+
# Should have 3 unique samples (one for each method)
175+
samples = detector.get_samples_for_ip(context_get.remote_address)
176+
assert len(samples) == 3
177+
178+
# Verify each method is present
179+
methods = {sample["method"] for sample in samples}
180+
assert methods == {"GET", "POST", "PUT"}
181+
182+
# All should have the same URL (due to test_utils limitation)
183+
for sample in samples:
184+
assert sample["url"] == context_get.url
185+
186+
187+
def test_samples_max_length():
188+
"""Test that samples are limited to maximum 10 samples"""
189+
detector = new_attack_wave_detector()
190+
191+
# Create a helper function to create contexts with different URLs
192+
def create_context_with_url(url, method="GET"):
193+
from aikido_zen.context import Context
194+
from aikido_zen.helpers.headers import Headers
195+
196+
headers = Headers()
197+
return Context(
198+
context_obj={
199+
"remote_address": "1.1.1.1",
200+
"method": method,
201+
"url": url,
202+
"query": {},
203+
"headers": headers,
204+
"body": None,
205+
"cookies": {},
206+
"source": "test",
207+
"route": "/test",
208+
"user": None,
209+
"executed_middleware": False,
210+
"parsed_userinput": {},
211+
}
212+
)
213+
214+
with patch(
215+
"aikido_zen.vulnerabilities.attack_wave_detection.attack_wave_detector.is_web_scanner",
216+
return_value=True,
217+
):
218+
# Create 15 unique contexts (different URLs) with different IPs to avoid cooldown
219+
unique_contexts = []
220+
for i in range(15):
221+
url = f"http://localhost:8080/unique-route-{i}"
222+
context = create_context_with_url(
223+
url, f"METHOD{i % 5}"
224+
) # Cycle through methods
225+
# Set different IP for each context to avoid cooldown
226+
context.remote_address = f"1.1.1.{i+1}"
227+
unique_contexts.append(context)
228+
229+
# Make requests with all unique contexts (each one should trigger attack wave)
230+
for context in unique_contexts:
231+
# Need to make 6 requests per context to trigger attack wave
232+
for i in range(6):
233+
detector.is_attack_wave(context)
234+
235+
# Check samples for the first IP - should have 1 sample
236+
samples = detector.get_samples_for_ip(unique_contexts[0].remote_address)
237+
assert len(samples) == 1
238+
239+
# Verify sample structure
240+
sample = samples[0]
241+
assert set(sample.keys()) == {"method", "url"}
242+
assert sample["method"] == "METHOD0"
243+
assert sample["url"] == "http://localhost:8080/unique-route-0"
244+
245+
246+
def test_samples_structure():
247+
"""Test that samples contain correct structure with method and URL only"""
248+
detector = new_attack_wave_detector()
249+
250+
# Create a context
251+
context = test_utils.generate_context(method="POST")
252+
253+
with patch(
254+
"aikido_zen.vulnerabilities.attack_wave_detection.attack_wave_detector.is_web_scanner",
255+
return_value=True,
256+
):
257+
# Make enough requests to trigger attack wave
258+
for i in range(6):
259+
detector.is_attack_wave(context)
260+
261+
# Get the samples
262+
samples = detector.get_samples_for_ip(context.remote_address)
263+
264+
# Should have 1 sample
265+
assert len(samples) == 1
266+
267+
# Verify sample structure (method and url only, no user_agent or timestamp)
268+
sample = samples[0]
269+
assert set(sample.keys()) == {"method", "url"}
270+
assert sample["method"] == "POST"
271+
assert sample["url"] == context.url
272+
assert "user_agent" not in sample
273+
assert "timestamp" not in sample
274+
275+
276+
def test_mixed_unique_and_duplicate_samples():
277+
"""Test mixed scenario with both unique and duplicate samples"""
278+
detector = new_attack_wave_detector()
279+
280+
# Create a helper function to create contexts with different URLs
281+
def create_context_with_url(url, method="GET"):
282+
from aikido_zen.context import Context
283+
from aikido_zen.helpers.headers import Headers
284+
285+
headers = Headers()
286+
return Context(
287+
context_obj={
288+
"remote_address": "1.1.1.1",
289+
"method": method,
290+
"url": url,
291+
"query": {},
292+
"headers": headers,
293+
"body": None,
294+
"cookies": {},
295+
"source": "test",
296+
"route": "/test",
297+
"user": None,
298+
"executed_middleware": False,
299+
"parsed_userinput": {},
300+
}
301+
)
302+
303+
# Create some unique contexts
304+
context_env = create_context_with_url("http://localhost:8080/.env", "GET")
305+
context_git = create_context_with_url("http://localhost:8080/.git/config", "POST")
306+
context_htaccess = create_context_with_url("http://localhost:8080/.htaccess", "PUT")
307+
308+
with patch(
309+
"aikido_zen.vulnerabilities.attack_wave_detection.attack_wave_detector.is_web_scanner",
310+
return_value=True,
311+
):
312+
# Add some unique samples (each needs 6 requests to trigger attack wave)
313+
# But use different IPs to avoid cooldown
314+
context_env.remote_address = "1.1.1.1"
315+
context_git.remote_address = "1.1.1.2"
316+
context_htaccess.remote_address = "1.1.1.3"
317+
318+
for i in range(6):
319+
detector.is_attack_wave(context_env)
320+
321+
for i in range(6):
322+
detector.is_attack_wave(context_git)
323+
324+
for i in range(6):
325+
detector.is_attack_wave(context_htaccess)
326+
327+
# Add many duplicate requests for the first context (same IP)
328+
for i in range(10):
329+
detector.is_attack_wave(context_env)
330+
331+
# Should have 1 unique sample for the first IP
332+
samples = detector.get_samples_for_ip(context_env.remote_address)
333+
assert len(samples) == 1
334+
335+
# Verify the sample structure
336+
sample = samples[0]
337+
assert set(sample.keys()) == {"method", "url"}
338+
assert sample["method"] == "GET"
339+
assert sample["url"] == "http://localhost:8080/.env"
340+
assert "user_agent" not in sample
341+
assert "timestamp" not in sample
342+
343+
126344
def test_samples_tracking():
127345
"""Test that samples are tracked correctly"""
128346
detector = new_attack_wave_detector()

0 commit comments

Comments
 (0)