Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

public final class SkipVulnerabilityScanDecider {
private SkipVulnerabilityScanDecider() {}
public static boolean shouldSkipVulnerabilityScan(ContextObject context) {
public static boolean shouldSkipVulnerabilityScan(ContextObject context, boolean defaultIfNoContext) {
if (context == null) {
return true;
return defaultIfNoContext;
}
if (context.getForcedProtectionOff().isEmpty()) {
ServiceConfiguration config = getConfig();
Expand All @@ -30,4 +30,7 @@ public static boolean shouldSkipVulnerabilityScan(ContextObject context) {
// Get stored forcedProtectionOff value from cache.
return context.getForcedProtectionOff().get();
}
public static boolean shouldSkipVulnerabilityScan(ContextObject context) {
return shouldSkipVulnerabilityScan(context, true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Map;

import static dev.aikido.agent_api.helpers.StackTrace.getCurrentStackTrace;
import static dev.aikido.agent_api.vulnerabilities.SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan;
import static dev.aikido.agent_api.vulnerabilities.ssrf.FindHostnameInContext.findHostnameInContext;
import static dev.aikido.agent_api.vulnerabilities.ssrf.IsPrivateIP.containsPrivateIP;
import static dev.aikido.agent_api.vulnerabilities.ssrf.PrivateIPRedirectFinder.isRedirectToPrivateIP;
Expand All @@ -25,7 +26,7 @@ public static Attack run(String hostname, int port, List<String> ipAddresses, St
}

ContextObject context = Context.get();
if(context == null) {
if (shouldSkipVulnerabilityScan(context)) {
return null;
}
FindHostnameInContext.Res attackFindings = findHostnameInContext(hostname, context, port);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.aikido.agent_api.vulnerabilities.ssrf;

import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.vulnerabilities.Attack;
import dev.aikido.agent_api.vulnerabilities.Vulnerabilities;

Expand All @@ -8,6 +10,7 @@
import java.util.Map;

import static dev.aikido.agent_api.helpers.StackTrace.getCurrentStackTrace;
import static dev.aikido.agent_api.vulnerabilities.SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan;
import static dev.aikido.agent_api.vulnerabilities.ssrf.imds.Resolver.resolvesToImdsIp;

public class StoredSSRFDetector {
Expand All @@ -21,6 +24,13 @@ public Attack run(String hostname, List<String> ipAddresses, String operation) {
return null;
}

ContextObject context = Context.get();
// the 2nd param makes it so that if context is not set, we default to false.
// this is necessary for stored SSRF where we don't want an early return even if there's no context.
if (shouldSkipVulnerabilityScan(context, false)) {
return null;
}

return new Attack(
operation,
Vulnerabilities.STORED_SSRF,
Expand Down
3 changes: 3 additions & 0 deletions agent_api/src/test/java/utils/EmptySampleContextObject.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,7 @@ public EmptySampleContextObject(String route, String method, Map<String, List<St
public void setIp(String ip) {
this.remoteAddress = ip;
}
public void setRoute(String route) {
this.route = route;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package vulnerabilities.ssrf;

import dev.aikido.agent_api.background.Endpoint;
import dev.aikido.agent_api.collectors.RedirectCollector;
import dev.aikido.agent_api.collectors.URLCollector;
import dev.aikido.agent_api.context.Context;
Expand All @@ -18,6 +19,7 @@

import static org.junit.jupiter.api.Assertions.*;
import static utils.EmptyAPIResponses.emptyAPIResponse;
import static utils.EmptyAPIResponses.setEmptyConfigWithEndpointList;

public class SSRFDetectorTest {
@BeforeAll
Expand All @@ -34,7 +36,18 @@ private void setContextAndLifecycle(String url) {
Context.set(new EmptySampleContextObject(url));
ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse);
}

private void setContextAndLifecycle(String url, String route) {
ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse);
setEmptyConfigWithEndpointList(List.of(
new Endpoint(
/* method */ "*", /* route */ "/api2/*",
/* rlm params */ 0, 0,
/* Allowed IPs */ List.of(), /* graphql */ false,
/* forceProtectionOff */ true, /* rlm */ false
)
));
Context.set(new EmptySampleContextObject(url, "http://localhost:3000" + route));
}

@Test
@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token")
Expand Down Expand Up @@ -144,4 +157,21 @@ public void testSsrfDetectorWithServiceHostnameInRedirect() throws MalformedURLE

assertNull(attackData);
}

@Test
@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token")
public void testSsrfDetectorForcedProtectionOff() throws MalformedURLException {
// Setup context :
setContextAndLifecycle("http://ssrf-redirects.testssandbox.com/", "/api2/forced-off-route");

URLCollector.report(new URL("http://ssrf-redirects.testssandbox.com/ssrf-test"));
RedirectCollector.report(new URL("http://ssrf-redirects.testssandbox.com/ssrf-test"), new URL("http://localhost"));
Attack attackData = SSRFDetector.run(
"localhost", 80,
List.of("127.0.0.1"),
"test2nd_op"
);

assertNull(attackData);
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,50 @@
package vulnerabilities.ssrf;

import dev.aikido.agent_api.background.Endpoint;
import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.storage.AttackQueue;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.vulnerabilities.Attack;
import dev.aikido.agent_api.vulnerabilities.ssrf.StoredSSRFDetector;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import utils.EmptySampleContextObject;

import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import static utils.EmptyAPIResponses.emptyAPIResponse;
import static utils.EmptyAPIResponses.setEmptyConfigWithEndpointList;

class StoredSSRFDetectorTest {

private final StoredSSRFDetector detector = new StoredSSRFDetector();

@BeforeEach
void setUp() throws InterruptedException {
AttackQueue.clear();
setEmptyConfigWithEndpointList(List.of(
new Endpoint(
/* method */ "*", /* route */ "/api2/*",
/* rlm params */ 0, 0,
/* Allowed IPs */ List.of(), /* graphql */ false,
/* forceProtectionOff */ true, /* rlm */ false
),
new Endpoint(
/* method */ "*", /* route */ "/api3/*",
/* rlm params */ 0, 0,
/* Allowed IPs */ List.of(), /* graphql */ false,
/* forceProtectionOff */ false, /* rlm */ false
)
));
}
@AfterEach
void cleanup() {
Context.set(null);
AttackQueue.clear();
ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse);
}

@Test
void run_WhenHostnameIsNull_ReturnsNull() {
Attack result = detector.run(null, List.of("169.254.169.254"), "testOperation");
Expand Down Expand Up @@ -70,6 +105,30 @@ void run_WhenIpIsIpv6ImdsIp_ReturnsAttack() {
assertNull(result.user);
}

@Test
void run_WhenProtectionForcedOff() {
// prepare forced off context
EmptySampleContextObject context1 = new EmptySampleContextObject("", "http://localhost:3000/api2/test/2/4");
context1.setRoute("/api2/test/2/4");
Context.set(context1);

Attack result = detector.run("test.example.com", List.of("fd00:ec2::254"), "testOperation");
assertNull(result);

Context.set(new EmptySampleContextObject());
result = detector.run("test.example.com", List.of("fd00:ec2::254"), "testOperation");
assertNotNull(result);

assertEquals("testOperation", result.operation);
assertEquals("stored_ssrf", result.kind);
assertEquals("test.example.com", result.payload);
assertEquals("test.example.com", result.metadata.get("hostname"));
assertEquals("fd00:ec2::254", result.metadata.get("privateIP"));
assertNull(result.source);
assertEquals("", result.pathToPayload);
assertNull(result.user);
}

@Test
void run_WhenIpIsIpv6ImdsIp_ReturnsAttackNotWhenIpIsHostname() {
Attack result = detector.run("fd00:ec2::254", List.of("fd00:ec2::254"), "testOperation");
Expand Down
Loading