diff --git a/benchmark/load/petclinic/benchmark.json b/benchmark/load/petclinic/benchmark.json index d569aeb4253..f679608d1f7 100644 --- a/benchmark/load/petclinic/benchmark.json +++ b/benchmark/load/petclinic/benchmark.json @@ -30,6 +30,12 @@ "JAVA_OPTS": "-javaagent:${TRACER} -Ddd.appsec.enabled=true" } }, + "apisec": { + "env": { + "VARIANT": "apisec", + "JAVA_OPTS": "-javaagent:${TRACER} -Ddd.appsec.enabled=true -Ddd.api-security.enabled=true" + } + }, "iast": { "env": { "VARIANT": "iast", diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java index 36349842a77..083219f7887 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java @@ -1,11 +1,11 @@ package com.datadog.appsec; +import com.datadog.appsec.api.security.ApiSecurityProcessor; import com.datadog.appsec.api.security.ApiSecuritySampler; -import com.datadog.appsec.api.security.ApiSecuritySamplerImpl; -import com.datadog.appsec.api.security.AppSecSpanPostProcessor; import com.datadog.appsec.blocking.BlockingServiceImpl; import com.datadog.appsec.config.AppSecConfigService; import com.datadog.appsec.config.AppSecConfigServiceImpl; +import com.datadog.appsec.config.TraceSegmentPostProcessor; import com.datadog.appsec.ddwaf.WAFModule; import com.datadog.appsec.event.EventDispatcher; import com.datadog.appsec.event.ReplaceableEventProducerService; @@ -23,7 +23,6 @@ import datadog.trace.api.telemetry.ProductChange; import datadog.trace.api.telemetry.ProductChangeCollector; import datadog.trace.bootstrap.ActiveSubsystems; -import datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -69,18 +68,6 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s EventDispatcher eventDispatcher = new EventDispatcher(); REPLACEABLE_EVENT_PRODUCER.replaceEventProducerService(eventDispatcher); - ApiSecuritySampler requestSampler; - if (Config.get().isApiSecurityEnabled()) { - requestSampler = new ApiSecuritySamplerImpl(); - // When DD_API_SECURITY_ENABLED=true, ths post-processor is set even when AppSec is inactive. - // This should be low overhead since the post-processor exits early if there's no AppSec - // context. - SpanPostProcessor.Holder.INSTANCE = - new AppSecSpanPostProcessor(requestSampler, REPLACEABLE_EVENT_PRODUCER); - } else { - requestSampler = new ApiSecuritySampler.NoOp(); - } - ConfigurationPoller configurationPoller = sco.configurationPoller(config); // may throw and abort startup APP_SEC_CONFIG_SERVICE = @@ -90,11 +77,15 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s sco.createRemaining(config); + TraceSegmentPostProcessor apiSecurityPostProcessor = + Config.get().isApiSecurityEnabled() + ? new ApiSecurityProcessor(new ApiSecuritySampler(), REPLACEABLE_EVENT_PRODUCER) + : null; GatewayBridge gatewayBridge = new GatewayBridge( gw, REPLACEABLE_EVENT_PRODUCER, - requestSampler, + apiSecurityPostProcessor, APP_SEC_CONFIG_SERVICE.getTraceSegmentPostProcessors()); loadModules(eventDispatcher, sco.monitoring); diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityProcessor.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityProcessor.java new file mode 100644 index 00000000000..add1de3872e --- /dev/null +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityProcessor.java @@ -0,0 +1,72 @@ +package com.datadog.appsec.api.security; + +import com.datadog.appsec.config.TraceSegmentPostProcessor; +import com.datadog.appsec.event.EventProducerService; +import com.datadog.appsec.event.ExpiredSubscriberInfoException; +import com.datadog.appsec.event.data.DataBundle; +import com.datadog.appsec.event.data.KnownAddresses; +import com.datadog.appsec.event.data.SingletonDataBundle; +import com.datadog.appsec.gateway.AppSecRequestContext; +import com.datadog.appsec.gateway.GatewayContext; +import com.datadog.appsec.report.AppSecEvent; +import datadog.trace.api.Config; +import datadog.trace.api.ProductTraceSource; +import datadog.trace.api.internal.TraceSegment; +import datadog.trace.bootstrap.instrumentation.api.Tags; +import java.util.Collection; +import java.util.Collections; +import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ApiSecurityProcessor implements TraceSegmentPostProcessor { + + private static final Logger log = LoggerFactory.getLogger(ApiSecurityProcessor.class); + private final ApiSecuritySampler sampler; + private final EventProducerService producerService; + + public ApiSecurityProcessor(ApiSecuritySampler sampler, EventProducerService producerService) { + this.sampler = sampler; + this.producerService = producerService; + } + + @Override + public void processTraceSegment( + TraceSegment segment, AppSecRequestContext ctx, Collection collectedEvents) { + if (segment == null || ctx == null) { + return; + } + if (!sampler.sample(ctx)) { + log.debug("Request not sampled, skipping API security post-processing"); + return; + } + log.debug("Request sampled, processing API security post-processing"); + extractSchemas(ctx, segment); + } + + private void extractSchemas( + final @Nonnull AppSecRequestContext ctx, final @Nonnull TraceSegment traceSegment) { + final EventProducerService.DataSubscriberInfo sub = + producerService.getDataSubscribers(KnownAddresses.WAF_CONTEXT_PROCESSOR); + if (sub == null || sub.isEmpty()) { + log.debug("No subscribers for schema extraction"); + return; + } + + final DataBundle bundle = + new SingletonDataBundle<>( + KnownAddresses.WAF_CONTEXT_PROCESSOR, Collections.singletonMap("extract-schema", true)); + try { + GatewayContext gwCtx = new GatewayContext(false); + producerService.publishDataEvent(sub, ctx, bundle, gwCtx); + // TODO: Perhaps do this if schemas have actually been extracted (check when committing + // derivatives) + traceSegment.setTagTop(Tags.ASM_KEEP, true); + if (!Config.get().isApmTracingEnabled()) { + traceSegment.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM); + } + } catch (ExpiredSubscriberInfoException e) { + log.debug("Subscriber info expired", e); + } + } +} diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySampler.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySampler.java index 4412a5d6303..8ace5a7c702 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySampler.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySampler.java @@ -1,35 +1,218 @@ package com.datadog.appsec.api.security; import com.datadog.appsec.gateway.AppSecRequestContext; -import javax.annotation.Nonnull; +import datadog.trace.util.AgentTaskScheduler; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; -public interface ApiSecuritySampler { - /** - * Prepare a request context for later sampling decision. This method should be called at request - * end, and is thread-safe. If a request can potentially be sampled, this method will return true. - * If this method returns true, the caller MUST call {@link #releaseOne()} once the context is not - * needed anymore. - */ - boolean preSampleRequest(final @Nonnull AppSecRequestContext ctx); +/** + * Internal map for API Security sampling. See "[RFC-1021] API Security Sampling Algorithm for + * thread-based concurrency". + */ +public class ApiSecuritySampler { - /** Get the final sampling decision. This method is NOT required to be thread-safe. */ - boolean sampleRequest(AppSecRequestContext ctx); + private static final int DEFAULT_MAX_ITEM_COUNT = 4096; + private static final int DEFAULT_INTERVAL_SECONDS = 30; - /** Release one permit for the sampler. This must be called after processing a span. */ - void releaseOne(); + private final MonotonicClock clock; + private final Executor executor; + private final int intervalSeconds; + private final AtomicReference table; + private final AtomicBoolean rebuild = new AtomicBoolean(false); + private final long zero; + private final long maxItemCount; - final class NoOp implements ApiSecuritySampler { - @Override - public boolean preSampleRequest(@Nonnull AppSecRequestContext ctx) { + public ApiSecuritySampler() { + this( + DEFAULT_MAX_ITEM_COUNT, + DEFAULT_INTERVAL_SECONDS, + new Random().nextLong(), + new DefaultMonotonicClock(), + AgentTaskScheduler.INSTANCE); + } + + public ApiSecuritySampler( + final int maxItemCount, + final int intervalSeconds, + final long zero, + final MonotonicClock clock, + Executor executor) { + table = new AtomicReference<>(new Table(maxItemCount)); + this.maxItemCount = maxItemCount; + this.intervalSeconds = intervalSeconds; + this.zero = zero; + this.clock = clock != null ? clock : new DefaultMonotonicClock(); + this.executor = executor != null ? executor : AgentTaskScheduler.INSTANCE; + } + + public boolean sample(AppSecRequestContext ctx) { + final String route = ctx.getRoute(); + if (route == null) { return false; } - - @Override - public boolean sampleRequest(AppSecRequestContext ctx) { + final String method = ctx.getMethod(); + if (method == null) { return false; } + final int statusCode = ctx.getResponseStatus(); + if (statusCode <= 0) { + return false; + } + final long hash = computeApiHash(route, method, statusCode); + return sample(hash); + } + + public boolean sample(long key) { + if (key == 0L) { + key = zero; + } + final int now = clock.now(); + final Table table = this.table.get(); + Table.FindSlotResult findSlotResult; + while (true) { + findSlotResult = table.findSlot(key); + if (!findSlotResult.exists) { + final int newCount = table.count.incrementAndGet(); + if (newCount > maxItemCount && rebuild.compareAndSet(false, true)) { + runRebuild(); + } + if (newCount > maxItemCount * 2) { + table.count.decrementAndGet(); + return false; + } + if (!findSlotResult.entry.key.compareAndSet(0, key)) { + if (findSlotResult.entry.key.get() == key) { + // Another thread just added this entry + return false; + } + // This entry was just claimed for another key, try another slot. + table.count.decrementAndGet(); + continue; + } + final long newEntryData = buildDataEntry(now, now); + if (findSlotResult.entry.data.compareAndSet(0, newEntryData)) { + return true; + } else { + return false; + } + } + break; + } + long curData = findSlotResult.entry.data.get(); + final int stime = getStime(curData); + final int deadline = now - intervalSeconds; + if (stime <= deadline) { + final long newData = buildDataEntry(now, now); + while (!findSlotResult.entry.data.compareAndSet(curData, newData)) { + curData = findSlotResult.entry.data.get(); + if (getStime(curData) == getAtime(curData)) { + // Another thread just issued a keep decision + return false; + } + if (getStime(curData) > now) { + // Another thread is in our fugure, but did not issue a keep decision. + return true; + } + } + return true; + } + final long newData = buildDataEntry(getStime(curData), now); + while (getAtime(curData) < now) { + if (!findSlotResult.entry.data.compareAndSet(curData, newData)) { + curData = findSlotResult.entry.data.get(); + } + } + return false; + } + + private void runRebuild() { + // TODO + } + + private static class Table { + private final Entry[] table; + private final AtomicInteger count = new AtomicInteger(0); + private final int maxItemCount; + + public Table(int maxItemCount) { + this.maxItemCount = maxItemCount; + final int size = 2 * maxItemCount + 1; + table = new Entry[size]; + for (int i = 0; i < size; i++) { + table[i] = new Entry(); + } + } + + public FindSlotResult findSlot(final long key) { + final int startIndex = (int) (key % (2L * maxItemCount)); + int index = startIndex; + do { + final Entry slot = table[index]; + final long slotKey = slot.key.get(); + if (slotKey == key) { + return new FindSlotResult(slot, true); + } else if (slotKey == 0L) { + return new FindSlotResult(slot, false); + } + index++; + if (index >= table.length) { + index = 0; + } + } while (index != startIndex); + return new FindSlotResult(table[(int) (maxItemCount * 2)], false); + } + + static class FindSlotResult { + public final Entry entry; + public final boolean exists; + + public FindSlotResult(final Entry entry, final boolean exists) { + this.entry = entry; + this.exists = exists; + } + } + + static class Entry { + private final AtomicLong key = new AtomicLong(0L); + private final AtomicLong data = new AtomicLong(0L); + } + } + + interface MonotonicClock { + int now(); + } + static class DefaultMonotonicClock implements MonotonicClock { @Override - public void releaseOne() {} + public int now() { + return (int) (System.nanoTime() / 1_000_000); + } + } + + long buildDataEntry(final int stime, final int atime) { + long result = stime; + result <<= 32; + result |= atime & 0xFFFFFFFFL; + return result; + } + + int getStime(final long data) { + return (int) (data >> 32); + } + + int getAtime(final long data) { + return (int) (data & 0xFFFFFFFFL); + } + + private long computeApiHash(final String route, final String method, final int statusCode) { + long result = 17; + result = 31 * result + route.hashCode(); + result = 31 * result + method.hashCode(); + result = 31 * result + statusCode; + return result; } } diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySamplerImpl.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySamplerImpl.java deleted file mode 100644 index c51bd46ef44..00000000000 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecuritySamplerImpl.java +++ /dev/null @@ -1,168 +0,0 @@ -package com.datadog.appsec.api.security; - -import com.datadog.appsec.gateway.AppSecRequestContext; -import datadog.trace.api.Config; -import datadog.trace.api.time.SystemTimeSource; -import datadog.trace.api.time.TimeSource; -import java.util.Deque; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.Semaphore; -import javax.annotation.Nonnull; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class ApiSecuritySamplerImpl implements ApiSecuritySampler { - - private static final Logger log = LoggerFactory.getLogger(ApiSecuritySamplerImpl.class); - - /** - * A maximum number of request contexts we'll keep open past the end of request at any given time. - * This will avoid excessive memory usage in case of a high number of concurrent requests, and - * should also prevent memory leaks. - */ - private static final int MAX_POST_PROCESSING_TASKS = 4; - /** Maximum number of entries in the access map. */ - private static final int MAX_SIZE = 4096; - /** Mapping from endpoint hash to last access timestamp in millis. */ - private final ConcurrentHashMap accessMap; - /** Deque of endpoint hashes ordered by access time. Oldest is always first. */ - private final Deque accessDeque; - - private final long expirationTimeInMs; - private final int capacity; - private final TimeSource timeSource; - private final Semaphore counter = new Semaphore(MAX_POST_PROCESSING_TASKS); - - public ApiSecuritySamplerImpl() { - this( - MAX_SIZE, - (long) (Config.get().getApiSecuritySampleDelay() * 1_000), - SystemTimeSource.INSTANCE); - } - - public ApiSecuritySamplerImpl( - int capacity, long expirationTimeInMs, @Nonnull TimeSource timeSource) { - this.capacity = capacity; - this.expirationTimeInMs = expirationTimeInMs; - this.accessMap = new ConcurrentHashMap<>(); - this.accessDeque = new ConcurrentLinkedDeque<>(); - this.timeSource = timeSource; - } - - @Override - public boolean preSampleRequest(final @Nonnull AppSecRequestContext ctx) { - final String route = ctx.getRoute(); - if (route == null) { - return false; - } - final String method = ctx.getMethod(); - if (method == null) { - return false; - } - final int statusCode = ctx.getResponseStatus(); - if (statusCode <= 0) { - return false; - } - long hash = computeApiHash(route, method, statusCode); - ctx.setApiSecurityEndpointHash(hash); - if (!isApiAccessExpired(hash)) { - return false; - } - if (counter.tryAcquire()) { - log.debug("API security sampling is required for this request (presampled)"); - ctx.setKeepOpenForApiSecurityPostProcessing(true); - return true; - } - return false; - } - - /** Get the final sampling decision. This method is NOT thread-safe. */ - @Override - public boolean sampleRequest(AppSecRequestContext ctx) { - if (ctx == null) { - return false; - } - final Long hash = ctx.getApiSecurityEndpointHash(); - if (hash == null) { - // This should never happen, it should have been short-circuited before. - return false; - } - return updateApiAccessIfExpired(hash); - } - - @Override - public void releaseOne() { - counter.release(); - } - - private boolean updateApiAccessIfExpired(final long hash) { - final long currentTime = timeSource.getCurrentTimeMillis(); - - Long lastAccess = accessMap.get(hash); - if (lastAccess != null && currentTime - lastAccess < expirationTimeInMs) { - return false; - } - - if (accessMap.put(hash, currentTime) == null) { - accessDeque.addLast(hash); - // If we added a new entry, we perform purging. - cleanupExpiredEntries(currentTime); - } else { - // This is now the most recently accessed entry. - accessDeque.remove(hash); - accessDeque.addLast(hash); - } - - return true; - } - - private boolean isApiAccessExpired(final long hash) { - final long currentTime = timeSource.getCurrentTimeMillis(); - final Long lastAccess = accessMap.get(hash); - return lastAccess == null || currentTime - lastAccess >= expirationTimeInMs; - } - - private void cleanupExpiredEntries(final long currentTime) { - // Purge all expired entries. - while (!accessDeque.isEmpty()) { - final Long oldestHash = accessDeque.peekFirst(); - if (oldestHash == null) { - // Should never happen - continue; - } - - final Long lastAccessTime = accessMap.get(oldestHash); - if (lastAccessTime == null) { - // Should never happen - continue; - } - - if (currentTime - lastAccessTime < expirationTimeInMs) { - // The oldest hash is up-to-date, so stop here. - break; - } - - accessDeque.pollFirst(); - accessMap.remove(oldestHash); - } - - // If we went over capacity, remove the oldest entries until we are within the limit. - // This should never be more than 1. - final int toRemove = accessMap.size() - this.capacity; - for (int i = 0; i < toRemove; i++) { - Long oldestHash = accessDeque.pollFirst(); - if (oldestHash != null) { - accessMap.remove(oldestHash); - } - } - } - - private long computeApiHash(final String route, final String method, final int statusCode) { - long result = 17; - result = 31 * result + route.hashCode(); - result = 31 * result + method.hashCode(); - result = 31 * result + statusCode; - return result; - } -} diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java deleted file mode 100644 index af97152609e..00000000000 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java +++ /dev/null @@ -1,92 +0,0 @@ -package com.datadog.appsec.api.security; - -import com.datadog.appsec.event.EventProducerService; -import com.datadog.appsec.event.ExpiredSubscriberInfoException; -import com.datadog.appsec.event.data.DataBundle; -import com.datadog.appsec.event.data.KnownAddresses; -import com.datadog.appsec.event.data.SingletonDataBundle; -import com.datadog.appsec.gateway.AppSecRequestContext; -import com.datadog.appsec.gateway.GatewayContext; -import datadog.trace.api.gateway.RequestContext; -import datadog.trace.api.gateway.RequestContextSlot; -import datadog.trace.api.internal.TraceSegment; -import datadog.trace.bootstrap.instrumentation.api.AgentSpan; -import datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor; -import java.util.Collections; -import java.util.function.BooleanSupplier; -import javax.annotation.Nonnull; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class AppSecSpanPostProcessor implements SpanPostProcessor { - - private static final Logger log = LoggerFactory.getLogger(AppSecSpanPostProcessor.class); - private final ApiSecuritySampler sampler; - private final EventProducerService producerService; - - public AppSecSpanPostProcessor(ApiSecuritySampler sampler, EventProducerService producerService) { - this.sampler = sampler; - this.producerService = producerService; - } - - @Override - public void process(@Nonnull AgentSpan span, @Nonnull BooleanSupplier timeoutCheck) { - final RequestContext ctx_ = span.getRequestContext(); - if (ctx_ == null) { - return; - } - final AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC); - if (ctx == null) { - return; - } - - if (!ctx.isKeepOpenForApiSecurityPostProcessing()) { - return; - } - - try { - if (timeoutCheck.getAsBoolean()) { - log.debug("Timeout detected, skipping API security post-processing"); - return; - } - if (!sampler.sampleRequest(ctx)) { - log.debug("Request not sampled, skipping API security post-processing"); - return; - } - log.debug("Request sampled, processing API security post-processing"); - extractSchemas(ctx, ctx_.getTraceSegment()); - } finally { - ctx.setKeepOpenForApiSecurityPostProcessing(false); - try { - // XXX: Close the additive first. This is not strictly needed, but it'll prevent getting it - // detected as a - // missed request-ended event. - ctx.closeWafContext(); - ctx.close(); - } catch (Exception e) { - log.debug("Error closing AppSecRequestContext", e); - } - sampler.releaseOne(); - } - } - - private void extractSchemas(final AppSecRequestContext ctx, final TraceSegment traceSegment) { - final EventProducerService.DataSubscriberInfo sub = - producerService.getDataSubscribers(KnownAddresses.WAF_CONTEXT_PROCESSOR); - if (sub == null || sub.isEmpty()) { - log.debug("No subscribers for schema extraction"); - return; - } - - final DataBundle bundle = - new SingletonDataBundle<>( - KnownAddresses.WAF_CONTEXT_PROCESSOR, Collections.singletonMap("extract-schema", true)); - try { - GatewayContext gwCtx = new GatewayContext(false); - producerService.publishDataEvent(sub, ctx, bundle, gwCtx); - ctx.commitDerivatives(traceSegment); - } catch (ExpiredSubscriberInfoException e) { - log.debug("Subscriber info expired", e); - } - } -} diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java index a73de84d7d0..8b859b4904d 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java @@ -142,9 +142,6 @@ public class AppSecRequestContext implements DataBundle, Closeable { // Used to detect missing request-end event at close. private volatile boolean requestEndCalled; - private volatile boolean keepOpenForApiSecurityPostProcessing; - private volatile Long apiSecurityEndpointHash; - private static final AtomicIntegerFieldUpdater WAF_TIMEOUTS_UPDATER = AtomicIntegerFieldUpdater.newUpdater(AppSecRequestContext.class, "wafTimeouts"); private static final AtomicIntegerFieldUpdater RASP_TIMEOUTS_UPDATER = @@ -343,22 +340,6 @@ public void setRoute(String route) { this.route = route; } - public void setKeepOpenForApiSecurityPostProcessing(final boolean flag) { - this.keepOpenForApiSecurityPostProcessing = flag; - } - - public boolean isKeepOpenForApiSecurityPostProcessing() { - return this.keepOpenForApiSecurityPostProcessing; - } - - public void setApiSecurityEndpointHash(long hash) { - this.apiSecurityEndpointHash = hash; - } - - public Long getApiSecurityEndpointHash() { - return this.apiSecurityEndpointHash; - } - void addRequestHeader(String name, String value) { if (finishedRequestHeaders) { throw new IllegalStateException("Request headers were said to be finished before"); @@ -554,23 +535,18 @@ public void close() { if (!requestEndCalled) { log.debug(SEND_TELEMETRY, "Request end event was not called before close"); } - // For API Security, we sometimes keep contexts open for late processing. In that case, this - // flag needs to be - // later reset by the API Security post-processor and close must be called again. - if (!keepOpenForApiSecurityPostProcessing) { - if (wafContext != null) { - log.debug( - SEND_TELEMETRY, "WAF object had not been closed (probably missed request-end event)"); - closeWafContext(); - } - collectedCookies = null; - requestHeaders.clear(); - responseHeaders.clear(); - persistentData.clear(); - if (derivatives != null) { - derivatives.clear(); - derivatives = null; - } + if (wafContext != null) { + log.debug( + SEND_TELEMETRY, "WAF object had not been closed (probably missed request-end event)"); + closeWafContext(); + } + collectedCookies = null; + requestHeaders.clear(); + responseHeaders.clear(); + persistentData.clear(); + if (derivatives != null) { + derivatives.clear(); + derivatives = null; } } diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java index b4bdb9b64c9..8c6e3b8be15 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java @@ -7,7 +7,6 @@ import static com.datadog.appsec.gateway.AppSecRequestContext.RESPONSE_HEADERS_ALLOW_LIST; import com.datadog.appsec.AppSecSystem; -import com.datadog.appsec.api.security.ApiSecuritySampler; import com.datadog.appsec.config.TraceSegmentPostProcessor; import com.datadog.appsec.event.EventProducerService; import com.datadog.appsec.event.EventProducerService.DataSubscriberInfo; @@ -89,7 +88,7 @@ public class GatewayBridge { private final SubscriptionService subscriptionService; private final EventProducerService producerService; - private final ApiSecuritySampler requestSampler; + private final TraceSegmentPostProcessor apiSecurityPostProcessor; private final List traceSegmentPostProcessors; // subscriber cache @@ -115,11 +114,11 @@ public class GatewayBridge { public GatewayBridge( SubscriptionService subscriptionService, EventProducerService producerService, - ApiSecuritySampler requestSampler, + TraceSegmentPostProcessor apiSecurityPostProcessor, List traceSegmentPostProcessors) { this.subscriptionService = subscriptionService; this.producerService = producerService; - this.requestSampler = requestSampler; + this.apiSecurityPostProcessor = apiSecurityPostProcessor; this.traceSegmentPostProcessors = traceSegmentPostProcessors; } @@ -680,15 +679,6 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) { TraceSegment traceSeg = ctx_.getTraceSegment(); Map tags = spanInfo.getTags(); - if (maybeSampleForApiSecurity(ctx, spanInfo, tags)) { - if (!Config.get().isApmTracingEnabled()) { - traceSeg.setTagTop(Tags.ASM_KEEP, true); - traceSeg.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM); - } - } else { - ctx.closeWafContext(); - } - // AppSec report metric and events for web span only if (traceSeg != null) { traceSeg.setTagTop("_dd.appsec.enabled", 1); @@ -696,6 +686,15 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) { Collection collectedEvents = ctx.transferCollectedEvents(); + if (apiSecurityPostProcessor != null) { + final Object route = tags.get(Tags.HTTP_ROUTE); + if (route != null) { + ctx.setRoute(route.toString()); + } + // TODO: Move this to traceSegmentPostProcessors + apiSecurityPostProcessor.processTraceSegment(traceSeg, ctx, null); + } + for (TraceSegmentPostProcessor pp : this.traceSegmentPostProcessors) { pp.processTraceSegment(traceSeg, ctx, collectedEvents); } @@ -749,6 +748,7 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) { writeRequestHeaders( traceSeg, DEFAULT_REQUEST_HEADERS_ALLOW_LIST, ctx.getRequestHeaders(), false); } + // If extracted any derivatives - commit them if (!ctx.commitDerivatives(traceSeg)) { log.debug("Unable to commit, derivatives will be skipped {}", ctx.getDerivativeKeys()); @@ -766,21 +766,11 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) { ); } + ctx.closeWafContext(); ctx.close(); return NoopFlow.INSTANCE; } - private boolean maybeSampleForApiSecurity( - AppSecRequestContext ctx, IGSpanInfo spanInfo, Map tags) { - log.debug("Checking API Security for end of request handler on span: {}", spanInfo.getSpanId()); - // API Security sampling requires http.route tag. - final Object route = tags.get(Tags.HTTP_ROUTE); - if (route != null) { - ctx.setRoute(route.toString()); - } - return requestSampler.preSampleRequest(ctx); - } - private Flow onRequestHeadersDone(RequestContext ctx_) { AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC); if (ctx == null || ctx.isReqDataPublished()) { diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityProcessorTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityProcessorTest.groovy new file mode 100644 index 00000000000..c266656dcc2 --- /dev/null +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityProcessorTest.groovy @@ -0,0 +1,134 @@ +package com.datadog.appsec.api.security + +import com.datadog.appsec.event.EventProducerService +import com.datadog.appsec.event.ExpiredSubscriberInfoException +import com.datadog.appsec.event.data.KnownAddresses +import com.datadog.appsec.gateway.AppSecRequestContext +import datadog.trace.api.ProductTraceSource +import datadog.trace.api.config.AppSecConfig +import datadog.trace.api.config.GeneralConfig +import datadog.trace.api.internal.TraceSegment +import datadog.trace.bootstrap.instrumentation.api.Tags +import datadog.trace.test.util.DDSpecification + +class ApiSecurityProcessorTest extends DDSpecification { + + void 'schema extracted on happy path'() { + given: + def sampler = Mock(ApiSecuritySampler) + def producer = Mock(EventProducerService) + def subInfo = Mock(EventProducerService.DataSubscriberInfo) + def traceSegment = Mock(TraceSegment) + def ctx = Mock(AppSecRequestContext) + def processor = new ApiSecurityProcessor(sampler, producer) + + when: + processor.processTraceSegment(traceSegment, ctx, null) + + then: + noExceptionThrown() + 1 * sampler.sample(ctx) >> true + 1 * producer.getDataSubscribers(KnownAddresses.WAF_CONTEXT_PROCESSOR) >> subInfo + 1 * subInfo.isEmpty() >> false + 1 * producer.publishDataEvent(_, ctx, _, _) + 1 * traceSegment.setTagTop('asm.keep', true) + 0 * _ + } + + void 'no schema extracted if sampling is false'() { + given: + def sampler = Mock(ApiSecuritySampler) + def producer = Mock(EventProducerService) + def ctx = Mock(AppSecRequestContext) + def traceSegment = Mock(TraceSegment) + def processor = new ApiSecurityProcessor(sampler, producer) + + when: + processor.processTraceSegment(traceSegment, ctx, null) + + then: + noExceptionThrown() + 1 * sampler.sample(ctx) >> false + 0 * _ + } + + void 'process null appsec request context does nothing'() { + given: + def sampler = Mock(ApiSecuritySampler) + def producer = Mock(EventProducerService) + def traceSegment = Mock(TraceSegment) + def processor = new ApiSecurityProcessor(sampler, producer) + + when: + processor.processTraceSegment(traceSegment, null, null) + + then: + noExceptionThrown() + 0 * _ + } + + void 'empty event subscription does not break the process'() { + given: + def sampler = Mock(ApiSecuritySampler) + def producer = Mock(EventProducerService) + def subInfo = Mock(EventProducerService.DataSubscriberInfo) + def traceSegment = Mock(TraceSegment) + def ctx = Mock(AppSecRequestContext) + def processor = new ApiSecurityProcessor(sampler, producer) + + when: + processor.processTraceSegment(traceSegment, ctx, null) + + then: + noExceptionThrown() + 1 * sampler.sample(ctx) >> true + 1 * producer.getDataSubscribers(_) >> subInfo + 1 * subInfo.isEmpty() >> true + 0 * _ + } + + void 'expired event subscription does not break the process'() { + given: + def sampler = Mock(ApiSecuritySampler) + def producer = Mock(EventProducerService) + def subInfo = Mock(EventProducerService.DataSubscriberInfo) + def traceSegment = Mock(TraceSegment) + def ctx = Mock(AppSecRequestContext) + def processor = new ApiSecurityProcessor(sampler, producer) + + when: + processor.processTraceSegment(traceSegment, ctx, null) + + then: + noExceptionThrown() + 1 * sampler.sample(ctx) >> true + 1 * producer.getDataSubscribers(_) >> subInfo + 1 * subInfo.isEmpty() >> false + 1 * producer.publishDataEvent(_, ctx, _, _) >> { throw new ExpiredSubscriberInfoException() } + 0 * _ + } + + void 'test api security sampling with tracing disabled'() { + given: + injectSysConfig(GeneralConfig.APM_TRACING_ENABLED, "false") + injectSysConfig(AppSecConfig.API_SECURITY_ENABLED, "true") + def sampler = Mock(ApiSecuritySampler) + def subInfo = Mock(EventProducerService.DataSubscriberInfo) + def producer = Mock(EventProducerService) + def traceSegment = Mock(TraceSegment) + def processor = new ApiSecurityProcessor(sampler, producer) + def ctx = Mock(AppSecRequestContext) + + when: + processor.processTraceSegment(traceSegment, ctx, null) + + then: + 1 * sampler.sample(ctx) >> true + 1 * producer.getDataSubscribers(KnownAddresses.WAF_CONTEXT_PROCESSOR) >> subInfo + 1 * subInfo.isEmpty() >> false + 1 * producer.publishDataEvent(_, ctx, _, _) + 1 * traceSegment.setTagTop('asm.keep', true) + 1 * traceSegment.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM) + 0 * _ + } +} diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy index a4ef9984786..6baf4e4cf2d 100644 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecuritySamplerTest.groovy @@ -1,219 +1,68 @@ package com.datadog.appsec.api.security -import com.datadog.appsec.gateway.AppSecRequestContext -import datadog.trace.api.time.ControllableTimeSource -import datadog.trace.test.util.DDSpecification +import spock.lang.Specification -class ApiSecuritySamplerTest extends DDSpecification { +import java.util.concurrent.Executor +import java.util.concurrent.Executors - void 'happy path with single request'() { - given: - final ctx = createContext('route1', 'GET', 200) - final sampler = new ApiSecuritySamplerImpl() +class ApiSecuritySamplerTest extends Specification { - when: - final preSampled = sampler.preSampleRequest(ctx) - - then: - preSampled - - when: - ctx.setKeepOpenForApiSecurityPostProcessing(true) - final sampled = sampler.sampleRequest(ctx) - - then: - sampled - } - - void 'second request is not sampled for the same endpoint'() { - given: - AppSecRequestContext ctx1 = createContext('route1', 'GET', 200) - AppSecRequestContext ctx2 = createContext('route1', 'GET', 200) - final sampler = new ApiSecuritySamplerImpl() - - when: - final preSampled1 = sampler.preSampleRequest(ctx1) - ctx1.setKeepOpenForApiSecurityPostProcessing(true) - final sampled1 = sampler.sampleRequest(ctx1) - sampler.releaseOne() - - then: - preSampled1 - sampled1 - - when: - final preSampled2 = sampler.preSampleRequest(ctx2) - - then: - !preSampled2 - } - - void 'preSampleRequest with maximum concurrent contexts'() { - given: - final ctx1 = Spy(createContext('route2', 'GET', 200)) - final ctx2 = Spy(createContext('route3', 'GET', 200)) - final sampler = new ApiSecuritySamplerImpl() - assert sampler.MAX_POST_PROCESSING_TASKS > 0 - - when: 'exhaust the maximum number of concurrent contexts' - final List preSampled1 = (1..sampler.MAX_POST_PROCESSING_TASKS).collect { - sampler.preSampleRequest(createContext('route1', 'GET', 200 + it)) - } - - then: - preSampled1.every { it } - - and: 'try to sample one more' - final preSampled2 = sampler.preSampleRequest(ctx1) - - then: - !preSampled2 - - when: 'release one context' - sampler.releaseOne() - - and: 'next can be sampled' - final preSampled3 = sampler.preSampleRequest(ctx2) - - then: - preSampled3 + static class SamplerArgs { + int maxItemCount = 8 + int intervalSeconds = 30 + long zero = 42L + TestClock clock = new TestClock() + Executor executor = Executors.newSingleThreadExecutor() } - void 'preSampleRequest with null route'() { - given: - def ctx = createContext(null, 'GET', 200) - def sampler = new ApiSecuritySamplerImpl() - - when: - def preSampled = sampler.preSampleRequest(ctx) - - then: - !preSampled - } - - void 'preSampleRequest with null method'() { - given: - def ctx = createContext('route1', null, 200) - def sampler = new ApiSecuritySamplerImpl() - - when: - def preSampled = sampler.preSampleRequest(ctx) - - then: - !preSampled + ApiSecuritySampler buildSampler(SamplerArgs args = new SamplerArgs()) { + return new ApiSecuritySampler(args.maxItemCount, args.intervalSeconds, args.zero, args.clock, args.executor) } - void 'preSampleRequest with 0 status code'() { - given: - def ctx = createContext('route1', 'GET', 0) - def sampler = new ApiSecuritySamplerImpl() - - when: - def preSampled = sampler.preSampleRequest(ctx) - - then: - !preSampled - } - - void 'sampleRequest with null context throws'() { - given: - def sampler = new ApiSecuritySamplerImpl() - - when: - sampler.preSampleRequest(null) - - then: - thrown(NullPointerException) - } - - void 'sampleRequest without prior preSampleRequest never works'() { - given: - def sampler = new ApiSecuritySamplerImpl() - def ctx = createContext('route1', 'GET', 200) - - when: - def sampled = sampler.sampleRequest(ctx) - - then: - !sampled - } - - void 'sampleRequest honors expiration'() { - given: - def ctx = createContext('route1', 'GET', 200) - ctx.setApiSecurityEndpointHash(42L) - ctx.setKeepOpenForApiSecurityPostProcessing(true) - final timeSource = new ControllableTimeSource() - timeSource.set(0) - final long expirationTimeInMs = 10L - final long expirationTimeInNs = expirationTimeInMs * 1_000_000 - def sampler = new ApiSecuritySamplerImpl(10, expirationTimeInMs, timeSource) - - when: - def sampled = sampler.sampleRequest(ctx) - - then: - sampled - - when: - sampled = sampler.sampleRequest(ctx) - - then: 'second request is not sampled' - !sampled - - when: 'expiration time has passed' - timeSource.advance(expirationTimeInNs) - sampled = sampler.sampleRequest(ctx) - - then: 'request is sampled again' - sampled - } - - void 'internal accessMap never goes beyond capacity'() { - given: - final timeSource = new ControllableTimeSource() - timeSource.set(0) - final long expirationTimeInMs = 10_000 - final int maxCapacity = 10 - ApiSecuritySamplerImpl sampler = new ApiSecuritySamplerImpl(10, expirationTimeInMs, timeSource) + void 'test single entry and no concurrency'() { + setup: + int intervalSeconds = 30 + TestClock clock = new TestClock() + ApiSecuritySampler sampler = buildSampler(new SamplerArgs(intervalSeconds: intervalSeconds, clock: clock)) expect: - for (int i = 0; i < maxCapacity * 10; i++) { - timeSource.advance(1_000_000) - final ctx = createContext('route1', 'GET', 200 + 1) - ctx.setApiSecurityEndpointHash(i as long) - ctx.setKeepOpenForApiSecurityPostProcessing(true) - assert sampler.sampleRequest(ctx) - assert sampler.accessMap.size() <= maxCapacity - } + sampler.sample(1L) + !sampler.sample(1L) + clock.inc(1) + !sampler.sample(1L) + // Increment time to just one second before the next interval + clock.inc(intervalSeconds - 2) + !sampler.sample(1L) + // Increment time to the next interval (exactly) + clock.inc(1) + sampler.sample(1L) + !sampler.sample(1L) } - void 'expired entries are purged from internal accessMap'() { - given: - final timeSource = new ControllableTimeSource() - timeSource.set(0) - final long expirationTimeInMs = 10_000 - final int maxCapacity = 10 - ApiSecuritySamplerImpl sampler = new ApiSecuritySamplerImpl(10, expirationTimeInMs, timeSource) + void 'test full map and no concurrency and no rebuilds'() { + setup: + int maxItemCount = 8 + int intervalSeconds = 30 + TestClock clock = new TestClock() + // Inhibit map rebuilding + Executor executor = Stub(Executor) + ApiSecuritySampler sampler = buildSampler(new SamplerArgs(maxItemCount: maxItemCount, intervalSeconds: intervalSeconds, clock: clock, executor: executor)) expect: - for (int i = 0; i < maxCapacity * 10; i++) { - final ctx = createContext('route1', 'GET', 200 + 1) - ctx.setApiSecurityEndpointHash(i as long) - ctx.setKeepOpenForApiSecurityPostProcessing(true) - assert sampler.sampleRequest(ctx) - assert sampler.accessMap.size() <= 2 - if (i % 2) { - timeSource.advance(expirationTimeInMs * 1_000_000) - } + for (int i = 0; i < maxItemCount * 2; i++) { + assert sampler.sample(i) } - } - - private static AppSecRequestContext createContext(final String route, final String method, int statusCode) { - final AppSecRequestContext ctx = new AppSecRequestContext() - ctx.setRoute(route) - ctx.setMethod(method) - ctx.setResponseStatus(statusCode) - ctx + for (int i = 0; i < maxItemCount * 2; i++) { + assert !sampler.sample(i) + } + assert !sampler.sample(Long.MAX_VALUE) + clock.inc(intervalSeconds) + for (int i = 0; i < maxItemCount * 2; i++) { + assert sampler.sample(i) + } + for (int i = 0; i < maxItemCount * 2; i++) { + assert !sampler.sample(i) + } + assert !sampler.sample(Long.MAX_VALUE) } } diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/AppSecSpanPostProcessorTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/AppSecSpanPostProcessorTest.groovy deleted file mode 100644 index 321f3876d94..00000000000 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/AppSecSpanPostProcessorTest.groovy +++ /dev/null @@ -1,251 +0,0 @@ -package com.datadog.appsec.api.security - -import com.datadog.appsec.event.EventProducerService -import com.datadog.appsec.event.ExpiredSubscriberInfoException -import com.datadog.appsec.event.data.KnownAddresses -import com.datadog.appsec.gateway.AppSecRequestContext -import datadog.trace.api.gateway.RequestContext -import datadog.trace.api.internal.TraceSegment -import datadog.trace.bootstrap.instrumentation.api.AgentSpan -import datadog.trace.test.util.DDSpecification - -class AppSecSpanPostProcessorTest extends DDSpecification { - - void 'schema extracted on happy path'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def subInfo = Mock(EventProducerService.DataSubscriberInfo) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def traceSegment = Mock(TraceSegment) - def ctx = Mock(AppSecRequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> ctx - 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true - 1 * sampler.sampleRequest(_) >> true - 1 * reqCtx.getTraceSegment() >> traceSegment - 1 * producer.getDataSubscribers(KnownAddresses.WAF_CONTEXT_PROCESSOR) >> subInfo - 1 * subInfo.isEmpty() >> false - 1 * producer.publishDataEvent(_, ctx, _, _) - 1 * ctx.commitDerivatives(traceSegment) - 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) - 1 * ctx.closeWafContext() - 1 * ctx.close() - 1 * sampler.releaseOne() - 0 * _ - } - - void 'no schema extracted if sampling is false'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def ctx = Mock(AppSecRequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> ctx - 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true - 1 * sampler.sampleRequest(_) >> false - 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) - 1 * ctx.closeWafContext() - 1 * ctx.close() - 1 * sampler.releaseOne() - 0 * _ - } - - void 'permit is released even if request context close throws'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def traceSegment = Mock(TraceSegment) - def ctx = Mock(AppSecRequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> ctx - 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true - 1 * sampler.sampleRequest(_) >> true - 1 * reqCtx.getTraceSegment() >> traceSegment - 1 * producer.getDataSubscribers(_) >> null - 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) - 1 * ctx.closeWafContext() - 1 * ctx.close() >> { throw new RuntimeException() } - 1 * sampler.releaseOne() - 0 * _ - } - - void 'context is cleaned up on timeout'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def ctx = Mock(AppSecRequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { true }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> ctx - 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true - 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) - 1 * ctx.closeWafContext() - 1 * ctx.close() - 1 * sampler.releaseOne() - 0 * _ - } - - void 'process null request context does nothing'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def span = Mock(AgentSpan) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> null - 0 * _ - } - - void 'process null appsec request context does nothing'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> null - 0 * _ - } - - void 'process already closed context does nothing'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def ctx = Mock(AppSecRequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> ctx - 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> false - 0 * _ - } - - void 'process throws on null span'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(null, { false }) - - then: - thrown(NullPointerException) - 0 * _ - } - - void 'empty event subscription does not break the process'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def subInfo = Mock(EventProducerService.DataSubscriberInfo) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def traceSegment = Mock(TraceSegment) - def ctx = Mock(AppSecRequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> ctx - 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true - 1 * sampler.sampleRequest(_) >> true - 1 * reqCtx.getTraceSegment() >> traceSegment - 1 * producer.getDataSubscribers(_) >> subInfo - 1 * subInfo.isEmpty() >> true - 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) - 1 * ctx.closeWafContext() - 1 * ctx.close() - 1 * sampler.releaseOne() - 0 * _ - } - - void 'expired event subscription does not break the process'() { - given: - def sampler = Mock(ApiSecuritySamplerImpl) - def producer = Mock(EventProducerService) - def subInfo = Mock(EventProducerService.DataSubscriberInfo) - def span = Mock(AgentSpan) - def reqCtx = Mock(RequestContext) - def traceSegment = Mock(TraceSegment) - def ctx = Mock(AppSecRequestContext) - def processor = new AppSecSpanPostProcessor(sampler, producer) - - when: - processor.process(span, { false }) - - then: - noExceptionThrown() - 1 * span.getRequestContext() >> reqCtx - 1 * reqCtx.getData(_) >> ctx - 1 * ctx.isKeepOpenForApiSecurityPostProcessing() >> true - 1 * sampler.sampleRequest(_) >> true - 1 * reqCtx.getTraceSegment() >> traceSegment - 1 * producer.getDataSubscribers(_) >> subInfo - 1 * subInfo.isEmpty() >> false - 1 * producer.publishDataEvent(_, ctx, _, _) >> { throw new ExpiredSubscriberInfoException() } - 1 * ctx.setKeepOpenForApiSecurityPostProcessing(false) - 1 * ctx.closeWafContext() - 1 * ctx.close() - 1 * sampler.releaseOne() - 0 * _ - } -} diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/TestClock.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/TestClock.groovy new file mode 100644 index 00000000000..b8d1ae6ddf1 --- /dev/null +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/TestClock.groovy @@ -0,0 +1,20 @@ +package com.datadog.appsec.api.security + +import groovy.transform.CompileStatic + +@CompileStatic +class TestClock implements ApiSecuritySampler.MonotonicClock { + + private int time = 0 + + int inc(final int delta) { + assert delta >= 0 : "Delta must be non-negative" + time += delta + return time + } + + @Override + int now() { + return time + } +} diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy index 6839b7061b9..1dfd5c9102b 100644 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy @@ -1,7 +1,7 @@ package com.datadog.appsec.gateway import com.datadog.appsec.AppSecSystem -import com.datadog.appsec.api.security.ApiSecuritySamplerImpl +import com.datadog.appsec.api.security.ApiSecurityProcessor import com.datadog.appsec.config.TraceSegmentPostProcessor import com.datadog.appsec.event.EventDispatcher import com.datadog.appsec.event.EventProducerService @@ -10,8 +10,6 @@ import com.datadog.appsec.event.data.KnownAddresses import com.datadog.appsec.report.AppSecEvent import com.datadog.appsec.report.AppSecEventWrapper import datadog.trace.api.ProductTraceSource -import datadog.trace.api.config.GeneralConfig -import static datadog.trace.api.config.IastConfig.IAST_DEDUPLICATION_ENABLED import datadog.trace.api.function.TriConsumer import datadog.trace.api.function.TriFunction import datadog.trace.api.gateway.BlockResponseFunction @@ -85,8 +83,8 @@ class GatewayBridgeSpecification extends DDSpecification { } TraceSegmentPostProcessor pp = Mock() - ApiSecuritySamplerImpl requestSampler = Mock(ApiSecuritySamplerImpl) - GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, requestSampler, [pp]) + ApiSecurityProcessor apiSecurityProcessor = Mock(ApiSecurityProcessor) + GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, apiSecurityProcessor, [pp]) Supplier> requestStartedCB BiFunction> requestEndedCB @@ -1180,7 +1178,7 @@ class GatewayBridgeSpecification extends DDSpecification { then: 1 * mockAppSecCtx.transferCollectedEvents() >> [] 1 * spanInfo.getTags() >> ['http.route': 'route'] - 1 * requestSampler.preSampleRequest(_) >> true + 1 * apiSecurityProcessor.processTraceSegment(_, _, _ ) 0 * traceSegment.setTagTop(Tags.ASM_KEEP, true) 0 * traceSegment.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM) } @@ -1198,31 +1196,11 @@ class GatewayBridgeSpecification extends DDSpecification { then: 1 * mockAppSecCtx.transferCollectedEvents() >> [] 1 * spanInfo.getTags() >> ['http.route': 'route'] - 1 * requestSampler.preSampleRequest(_) >> false + 1 * apiSecurityProcessor.processTraceSegment(_, _, _ ) 0 * traceSegment.setTagTop(Tags.ASM_KEEP, true) 0 * traceSegment.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM) } - void 'test api security sampling with tracing disabled'() { - given: - injectSysConfig(GeneralConfig.APM_TRACING_ENABLED, "false") - AppSecRequestContext mockAppSecCtx = Mock(AppSecRequestContext) - RequestContext mockCtx = Stub(RequestContext) { - getData(RequestContextSlot.APPSEC) >> mockAppSecCtx - getTraceSegment() >> traceSegment - } - IGSpanInfo spanInfo = Mock(AgentSpan) - when: - def flow = requestEndedCB.apply(mockCtx, spanInfo) - then: - 1 * mockAppSecCtx.transferCollectedEvents() >> [] - 1 * spanInfo.getTags() >> ['http.route': 'route'] - 1 * requestSampler.preSampleRequest(_) >> true - 1 * traceSegment.setTagTop(Tags.ASM_KEEP, true) - 1 * traceSegment.setTagTop(Tags.PROPAGATED_TRACE_SOURCE, ProductTraceSource.ASM) - } - - void 'test default writeRequestHeaders'(){ given: def allowedHeaders = ['x-allowed-header', 'x-multiple-allowed-header', 'x-always-included'] as Set diff --git a/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java b/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java index 6cd0ecdaed2..39acc558301 100644 --- a/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java +++ b/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java @@ -6,8 +6,6 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import datadog.communication.ddagent.DroppingPolicy; -import datadog.trace.api.Config; -import datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor; import datadog.trace.common.sampling.SingleSpanSampler; import datadog.trace.common.writer.ddagent.FlushEvent; import datadog.trace.common.writer.ddagent.Prioritization; @@ -18,7 +16,6 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.function.BooleanSupplier; import org.jctools.queues.MessagePassingQueue; import org.jctools.queues.MpscBlockingConsumerArrayQueue; import org.slf4j.Logger; @@ -184,7 +181,6 @@ public void onEvent(Object event) { try { if (event instanceof List) { List trace = (List) event; - maybeTracePostProcessing(trace); // TODO populate `_sample_rate` metric in a way that accounts for lost/dropped traces payloadDispatcher.addTrace(trace); } else if (event instanceof FlushEvent) { @@ -245,33 +241,5 @@ private void consumeBatch(MessagePassingQueue queue) { protected boolean queuesAreEmpty() { return primaryQueue.isEmpty() && secondaryQueue.isEmpty(); } - - private void maybeTracePostProcessing(List trace) { - if (trace == null || trace.isEmpty()) { - return; - } - - final SpanPostProcessor postProcessor = SpanPostProcessor.Holder.INSTANCE; - try { - final long timeout = Config.get().getTracePostProcessingTimeout(); - final long deadline = System.currentTimeMillis() + timeout; - final boolean[] timedOut = {false}; - final BooleanSupplier timeoutCheck = - () -> { - if (timedOut[0]) { - return true; - } - if (System.currentTimeMillis() > deadline) { - timedOut[0] = true; - } - return timedOut[0]; - }; - for (DDSpan span : trace) { - postProcessor.process(span, timeoutCheck); - } - } catch (Throwable e) { - log.debug("Error while trace post-processing", e); - } - } } } diff --git a/dd-trace-core/src/test/groovy/datadog/trace/common/writer/TraceProcessingWorkerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/common/writer/TraceProcessingWorkerTest.groovy index 2a37a69ec64..dd823a587be 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/common/writer/TraceProcessingWorkerTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/common/writer/TraceProcessingWorkerTest.groovy @@ -4,10 +4,7 @@ import datadog.trace.common.sampling.SingleSpanSampler import datadog.trace.common.writer.ddagent.PrioritizationStrategy.PublishResult import datadog.trace.core.CoreSpan import datadog.trace.core.DDSpan -import datadog.trace.core.DDSpanContext -import datadog.trace.core.PendingTrace import datadog.trace.core.monitor.HealthMetrics -import datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor import datadog.trace.test.util.DDSpecification import spock.util.concurrent.PollingConditions @@ -152,56 +149,6 @@ class TraceProcessingWorkerTest extends DDSpecification { priority << [SAMPLER_DROP, USER_DROP, SAMPLER_KEEP, USER_KEEP, UNSET] } - def "trace should be post-processed"() { - setup: - AtomicInteger acceptedCount = new AtomicInteger() - PayloadDispatcherImpl countingDispatcher = Mock(PayloadDispatcherImpl) - countingDispatcher.addTrace(_) >> { - acceptedCount.getAndIncrement() - } - HealthMetrics healthMetrics = Mock(HealthMetrics) - - def span1 = DDSpan.create("test", 0, Mock(DDSpanContext) { - getTraceCollector() >> Mock(PendingTrace) { - getCurrentTimeNano() >> 0 - } - }, []) - def processedSpan1 = false - - // Span 2 - should NOT be post-processed - def span2 = DDSpan.create("test", 0, Mock(DDSpanContext) { - getTraceCollector() >> Mock(PendingTrace) { - getCurrentTimeNano() >> 0 - } - }, []) - def processedSpan2 = false - - SpanPostProcessor.Holder.INSTANCE = Mock(SpanPostProcessor) { - process(span1, _) >> { processedSpan1 = true } - process(span2, _) >> { processedSpan2 = true } - } - - TraceProcessingWorker worker = new TraceProcessingWorker(10, healthMetrics, - countingDispatcher, { - false - }, FAST_LANE, 100, TimeUnit.SECONDS, null) - worker.start() - - when: "traces are submitted" - worker.publish(span1, SAMPLER_KEEP, [span1, span2]) - worker.publish(span2, SAMPLER_KEEP, [span1, span2]) - - then: "traces are passed through unless rejected on submission" - conditions.eventually { - assert processedSpan1 - assert processedSpan2 - } - - cleanup: - SpanPostProcessor.Holder.INSTANCE = SpanPostProcessor.Holder.NOOP - worker.close() - } - def "traces should be processed"() { setup: AtomicInteger acceptedCount = new AtomicInteger() diff --git a/internal-api/build.gradle b/internal-api/build.gradle index fbf1916a4af..d395cd5eb87 100644 --- a/internal-api/build.gradle +++ b/internal-api/build.gradle @@ -220,10 +220,7 @@ excludedClassesCoverage += [ 'datadog.trace.util.stacktrace.StackTraceFrame', 'datadog.trace.api.iast.VulnerabilityMarks', 'datadog.trace.api.iast.securitycontrol.SecurityControlHelper', - 'datadog.trace.api.iast.securitycontrol.SecurityControl', - // Trivial holder and no-op - 'datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor.Holder', - 'datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor.NoOpSpanPostProcessor', + 'datadog.trace.api.iast.securitycontrol.SecurityControl' ] excludedClassesBranchCoverage = [ 'datadog.trace.api.ProductActivationConfig', diff --git a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/SpanPostProcessor.java b/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/SpanPostProcessor.java deleted file mode 100644 index 137ddedba1d..00000000000 --- a/internal-api/src/main/java/datadog/trace/bootstrap/instrumentation/api/SpanPostProcessor.java +++ /dev/null @@ -1,37 +0,0 @@ -package datadog.trace.bootstrap.instrumentation.api; - -import java.util.function.BooleanSupplier; -import javax.annotation.Nonnull; - -/** - * Applies post-processing of spans before serialization. - * - *

Post-processing runs in TraceProcessingWorker thread. This provides the following properties: - *

  • Runs in a single thread. Post-processing for each span runs sequentially. - *
  • Runs after the request end, and does not block the main thread. - *
  • Runs at a point where the sampler decision is already available. - */ -public interface SpanPostProcessor { - - /** - * Post-processes a span, if needed. - * - *

    Implementations should use {@code timeoutCheck}, and if true, they should halt processing as - * much as possible. This method is guaranteed to be called even if post-processing of previous - * spans have timed out. - */ - void process(@Nonnull AgentSpan span, @Nonnull BooleanSupplier timeoutCheck); - - class Holder { - public static final SpanPostProcessor NOOP = new NoOpSpanPostProcessor(); - - // XXX: At the moment, a single post-processor can be registered, and only AppSec defines one. - // If other products add their own, we'll need to refactor this to support multiple processors. - public static volatile SpanPostProcessor INSTANCE = NOOP; - } - - class NoOpSpanPostProcessor implements SpanPostProcessor { - @Override - public void process(@Nonnull AgentSpan span, @Nonnull BooleanSupplier timeoutCheck) {} - } -}