Skip to content

Collect stats about rate limited requests #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import dev.aikido.agent_api.ratelimiting.ShouldRateLimit;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.storage.ServiceConfiguration;
import dev.aikido.agent_api.storage.routes.RoutesStore;
import dev.aikido.agent_api.storage.statistics.StatisticsStore;

public final class ShouldBlockRequest {
private ShouldBlockRequest() {
Expand Down Expand Up @@ -34,6 +36,10 @@ public static ShouldBlockRequestResult shouldBlockRequest() {
context.getRouteMetadata(), context.getUser(), context.getRemoteAddress()
);
if (rateLimitDecision.block()) {
// increment rate-limiting stats both globally and on the route :
StatisticsStore.incrementRateLimited();
RoutesStore.addRouteRateLimitedCount(context.getRouteMetadata());

BlockedRequestResult blockedRequestResult = new BlockedRequestResult(
"ratelimited", rateLimitDecision.trigger(), context.getRemoteAddress()
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package dev.aikido.agent_api.storage.routes;

import com.google.gson.*;
import dev.aikido.agent_api.api_discovery.APISpec;
import dev.aikido.agent_api.context.RouteMetadata;

import java.lang.reflect.Type;

import static dev.aikido.agent_api.api_discovery.APISpecMerger.mergeAPISpecs;

public class RouteEntry {
final String method;
final String path;
private int hits;
private int rateLimitedCount;
private APISpec apispec;

public RouteEntry(String method, String path) {
Expand All @@ -32,6 +30,13 @@ public int getHits() {
return hits;
}

public void incrementRateLimitCount() {
rateLimitedCount++;
}

public int getRateLimitCount() {
return rateLimitedCount;
}
public void updateApiSpec(APISpec newApiSpec) {
this.apispec = mergeAPISpecs(newApiSpec, this.apispec);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,30 @@ public Routes() {
this(1000); // Default max size
}

private void initializeRoute(RouteMetadata routeMetadata) {
private void ensureRoute(RouteMetadata routeMetadata) {
manageRoutesSize();
String key = routeToKey(routeMetadata);
routes.put(key, new RouteEntry(routeMetadata));
if(!routes.containsKey(key)) {
routes.put(key, new RouteEntry(routeMetadata));
}
}

public void incrementRoute(RouteMetadata routeMetadata) {
String key = routeToKey(routeMetadata);
if (!routes.containsKey(key)) {
// if the route does not yet exist, create it.
initializeRoute(routeMetadata);
}
RouteEntry route = routes.get(key);
ensureRoute(routeMetadata);
RouteEntry route = this.get(routeMetadata);
if (route != null) {
route.incrementHits();
}
}

public void incrementRateLimitCount(RouteMetadata routeMetadata) {
ensureRoute(routeMetadata);
RouteEntry route = this.get(routeMetadata);
if (route != null) {
route.incrementRateLimitCount();
}
}

public RouteEntry get(RouteMetadata routeMetadata) {
String key = routeToKey(routeMetadata);
return routes.get(key);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@
}
}

public static void addRouteRateLimitedCount(RouteMetadata routeMetadata) {
mutex.lock();
try {
routes.incrementRateLimitCount(routeMetadata);
} catch (Throwable e) {
logger.debug("Error occurred incrementing route rate limit count: %s", e.getMessage());

Check warning on line 67 in agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RoutesStore.java

View check run for this annotation

Codecov / codecov/patch

agent_api/src/main/java/dev/aikido/agent_api/storage/routes/RoutesStore.java#L66-L67

Added lines #L66 - L67 were not covered by tests
} finally {
mutex.unlock();
}
}

public static void clear() {
mutex.lock();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ public class Statistics {
private final Map<String, Integer> ipAddressMatches = new HashMap<>();
private final Map<String, Integer> userAgentMatches = new HashMap<>();
private int totalHits;
private final int aborted; // We don't use the "aborted" field right now
private int rateLimited;
private int attacksDetected;
private int attacksBlocked;
private long startedAt;

public Statistics(int totalHits, int attacksDetected, int attacksBlocked) {
this.totalHits = totalHits;
this.attacksDetected = attacksDetected;
this.attacksBlocked = attacksBlocked;
this.startedAt = UnixTimeMS.getUnixTimeMS();
}

public Statistics() {
this(0, 0, 0);
this.totalHits = 0;
this.rateLimited = 0;
this.aborted = 0;
this.attacksDetected = 0;
this.attacksBlocked = 0;
this.startedAt = UnixTimeMS.getUnixTimeMS();
}


Expand All @@ -35,6 +35,14 @@ public int getTotalHits() {
return totalHits;
}

public void incrementRateLimited() {
rateLimited += 1;
}

public int getRateLimited() {
return rateLimited;
}


// attack stats
public void incrementAttacksDetected(String operation) {
Expand Down Expand Up @@ -104,8 +112,7 @@ public void addMatchToUserAgents(String key) {
public StatsRecord getRecord() {
long endedAt = UnixTimeMS.getUnixTimeMS();
return new StatsRecord(this.startedAt, endedAt, new StatsRequestsRecord(
/* total */ totalHits,
/* aborted */ 0, // Unknown statistic, default to 0,
totalHits, aborted, rateLimited,
/* attacksDetected */ Map.of(
"total", attacksDetected,
"blocked", attacksBlocked
Expand All @@ -118,6 +125,7 @@ public StatsRecord getRecord() {

public void clear() {
this.totalHits = 0;
this.rateLimited = 0;
this.attacksBlocked = 0;
this.attacksDetected = 0;
this.startedAt = UnixTimeMS.getUnixTimeMS();
Expand All @@ -127,7 +135,8 @@ public void clear() {
}

// Stats records for sending out the heartbeat :
public record StatsRequestsRecord(long total, long aborted, Map<String, Integer> attacksDetected) {
public record StatsRequestsRecord(long total, long aborted, long rateLimited,
Map<String, Integer> attacksDetected) {
}

public record StatsRecord(long startedAt, long endedAt, StatsRequestsRecord requests,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ public static void incrementHits() {
}
}

public static void incrementRateLimited() {
mutex.lock();
try {
stats.incrementRateLimited();
} finally {
mutex.unlock();
}
}

public static void incrementAttacksDetected(String operation) {
mutex.lock();
try {
Expand Down
31 changes: 27 additions & 4 deletions agent_api/src/test/java/ShouldBlockRequestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.context.User;
import dev.aikido.agent_api.storage.RateLimiterStore;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.storage.routes.RoutesStore;
import dev.aikido.agent_api.storage.statistics.StatisticsStore;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -38,12 +41,18 @@ public SampleContextObject() {
public static void clean() {
Context.set(null);
ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse);
StatisticsStore.clear();
RoutesStore.clear();
RateLimiterStore.clear();
};

@AfterEach
public void tearDown() throws SQLException {
Context.set(null);
ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse);
StatisticsStore.clear();
RoutesStore.clear();
RateLimiterStore.clear();
}

@Test
Expand All @@ -59,6 +68,7 @@ public void testNoContext() throws SQLException {
// Test with thread cache not set :
var res2 = ShouldBlockRequest.shouldBlockRequest();
assertFalse(res2.block());
assertEquals(0, StatisticsStore.getStatsRecord().requests().rateLimited());
}

@Test
Expand Down Expand Up @@ -112,7 +122,8 @@ public void testUserSet() throws SQLException {

@Test
public void testEndpointsExistButNoMatch() throws SQLException {
Context.set(null);
ContextObject ctx = new SampleContextObject();
Context.set(ctx);
setEmptyConfigWithEndpointList(List.of(
new Endpoint("POST", "/api2/*", 1, 1000, Collections.emptyList(), false, false, false)
));
Expand All @@ -121,7 +132,6 @@ public void testEndpointsExistButNoMatch() throws SQLException {
var res1 = ShouldBlockRequest.shouldBlockRequest();
assertFalse(res1.block());

Context.set(null);
setEmptyConfigWithEndpointList(List.of(
new Endpoint("POST", "/api2/*", 1, 1000, Collections.emptyList(), false, false, true)
));
Expand All @@ -133,7 +143,8 @@ public void testEndpointsExistButNoMatch() throws SQLException {

@Test
public void testEndpointsExistWithMatch() throws SQLException {
Context.set(null);
ContextObject ctx = new SampleContextObject();
Context.set(ctx);
setEmptyConfigWithEndpointList(List.of(
new Endpoint("GET", "/api/*", 1, 1000, Collections.emptyList(), false, false, false)
));
Expand All @@ -142,14 +153,26 @@ public void testEndpointsExistWithMatch() throws SQLException {
var res1 = ShouldBlockRequest.shouldBlockRequest();
assertFalse(res1.block());

Context.set(null);
setEmptyConfigWithEndpointList(List.of(
new Endpoint("GET", "/api/*", 1, 1000, Collections.emptyList(), false, false, true)
));

// Test with match & rate-limiting enabled :
var res2 = ShouldBlockRequest.shouldBlockRequest();
assertFalse(res2.block());
assertEquals(0, StatisticsStore.getStatsRecord().requests().rateLimited());


var res3 = ShouldBlockRequest.shouldBlockRequest();
var res4 = ShouldBlockRequest.shouldBlockRequest();
assertTrue(res3.block());
assertTrue(res4.block());
assertEquals("ip", res3.data().trigger());
assertEquals("192.168.1.1", res3.data().ip());
assertEquals("ratelimited", res3.data().type());
assertEquals(2, StatisticsStore.getStatsRecord().requests().rateLimited());
assertEquals(2, RoutesStore.getRoutesAsList()[0].getRateLimitCount());

}

@Test
Expand Down
15 changes: 14 additions & 1 deletion agent_api/src/test/java/storage/RouteEntryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,22 @@ public void testGsonWithoutSerializer() throws IOException {
Gson gson = new Gson();
String json = gson.toJson(route1);
assertEquals(
"{\"method\":\"GET\",\"path\":\"/api/1\",\"hits\":0,\"apispec\":{\"body\":{\"schema\":{\"type\":\"object\",\"properties\":{\"oldProp\":{\"type\":\"string\",\"optional\":false}},\"optional\":false},\"type\":\"oldType\"},\"auth\":[{\"type\":\"apiKey\"}]}}",
"{\"method\":\"GET\",\"path\":\"/api/1\",\"hits\":0,\"rateLimitedCount\":0,\"apispec\":{\"body\":{\"schema\":{\"type\":\"object\",\"properties\":{\"oldProp\":{\"type\":\"string\",\"optional\":false}},\"optional\":false},\"type\":\"oldType\"},\"auth\":[{\"type\":\"apiKey\"}]}}",
json
);
}

@Test
public void testIncrementRateLimitedCount() {
// Initial count should be 0
assertEquals(0, route1.getRateLimitCount());

// Increment the rate limited count
route1.incrementRateLimitCount();
assertEquals(1, route1.getRateLimitCount());

// Increment again
route1.incrementRateLimitCount();
assertEquals(2, route1.getRateLimitCount());
}
}
14 changes: 14 additions & 0 deletions agent_api/src/test/java/storage/RoutesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ void testIncrementNonExistentRoute() {
assertEquals(1, routes.size());
}

@Test
void testIncrementRouteRateLimitCount() {
routes.incrementRateLimitCount(routeMetadata1);
RouteEntry entry = routes.get(routeMetadata1);
assertNotNull(entry);
assertEquals(1, entry.getRateLimitCount());
}

@Test
void testIncrementNonExistentRouteRateLimit() {
routes.incrementRateLimitCount(routeMetadata1);
assertEquals(1, routes.size());
}

@Test
void testManageRoutesSize() {
routes.incrementRoute(routeMetadata1);
Expand Down
23 changes: 17 additions & 6 deletions agent_api/src/test/java/storage/StatisticsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ public void testClear() {
stats.incrementAttacksDetected("test2");
stats.incrementAttacksDetected("test1");
stats.incrementAttacksDetected("test1");
stats.incrementRateLimited();
assertEquals(3, stats.getAttacksDetected());
assertEquals(2, stats.getAttacksBlocked());
assertEquals(20, stats.getTotalHits());
assertEquals(1, stats.getRateLimited());
assertEquals(2, stats.getOperations().get("test1").getAttacksDetected().get("total"));
assertEquals(1, stats.getOperations().get("test1").getAttacksDetected().get("blocked"));

Expand All @@ -47,20 +49,29 @@ public void testClear() {
assertEquals(0, stats.getAttacksBlocked());
assertEquals(0, stats.getAttacksDetected());
assertEquals(0, stats.getTotalHits());

assertEquals(0, stats.getRateLimited());
}

@Test
public void testConstructor() {
Statistics stats2 = new Statistics(100, 5, 1);
assertEquals(100, stats2.getTotalHits());
assertEquals(5, stats2.getAttacksDetected());
assertEquals(1, stats2.getAttacksBlocked());
Statistics stats2 = new Statistics();
assertEquals(0, stats2.getTotalHits());
assertEquals(0, stats2.getRateLimited());
assertEquals(0, stats2.getAttacksDetected());
assertEquals(0, stats2.getAttacksBlocked());
}

@Test
public void testStatsRecord() {
Statistics stats2 = new Statistics(100, 5, 1);
Statistics stats2 = new Statistics();
stats2.incrementTotalHits(100);
stats2.incrementAttacksDetected("op2");
stats2.incrementAttacksDetected("op2");
stats2.incrementAttacksDetected("op2");
stats2.incrementAttacksDetected("op2");
stats2.incrementAttacksDetected("op2");
stats2.incrementAttacksBlocked("op2");

stats2.registerCall("operation1", OperationKind.FS_OP);
Statistics.StatsRecord statsRecord = stats2.getRecord();
assertEquals(5, statsRecord.requests().attacksDetected().get("total"));
Expand Down