Skip to content

Add FilteringAdapter in SQS #1388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,20 @@
package io.awspring.cloud.sqs.config;

import io.awspring.cloud.sqs.ConfigUtils;
import io.awspring.cloud.sqs.listener.AbstractMessageListenerContainer;
import io.awspring.cloud.sqs.listener.AsyncComponentAdapters;
import io.awspring.cloud.sqs.listener.AsyncMessageListener;
import io.awspring.cloud.sqs.listener.ContainerComponentFactory;
import io.awspring.cloud.sqs.listener.ContainerOptions;
import io.awspring.cloud.sqs.listener.ContainerOptionsBuilder;
import io.awspring.cloud.sqs.listener.MessageListener;
import io.awspring.cloud.sqs.listener.MessageListenerContainer;
import io.awspring.cloud.sqs.listener.*;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler;
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.function.Consumer;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

/**
* Base implementation for a {@link MessageListenerContainerFactory}. Contains the components and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,24 @@

import io.awspring.cloud.sqs.ConfigUtils;
import io.awspring.cloud.sqs.annotation.SqsListener;
import io.awspring.cloud.sqs.listener.AsyncMessageListener;
import io.awspring.cloud.sqs.listener.ContainerComponentFactory;
import io.awspring.cloud.sqs.listener.ContainerOptions;
import io.awspring.cloud.sqs.listener.MessageListener;
import io.awspring.cloud.sqs.listener.SqsContainerOptions;
import io.awspring.cloud.sqs.listener.SqsContainerOptionsBuilder;
import io.awspring.cloud.sqs.listener.SqsMessageListenerContainer;
import io.awspring.cloud.sqs.listener.*;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.acknowledgement.AsyncAcknowledgementResultCallback;
import io.awspring.cloud.sqs.listener.errorhandler.AsyncErrorHandler;
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;

import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
* {@link MessageListenerContainerFactory} implementation for creating {@link SqsMessageListenerContainer} instances. A
* factory can be assigned to a {@link io.awspring.cloud.sqs.annotation.SqsListener @SqsListener} by using the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter;
import java.time.Duration;

import io.awspring.cloud.sqs.support.filter.DefaultMessageFilter;
import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.retry.backoff.BackOffPolicy;
Expand Down Expand Up @@ -59,6 +62,8 @@ public abstract class AbstractContainerOptions<O extends ContainerOptions<O, B>,

private final AcknowledgementMode acknowledgementMode;

private final MessageFilter<?> messageFilter;

@Nullable
private final AcknowledgementOrdering acknowledgementOrdering;

Expand Down Expand Up @@ -92,6 +97,8 @@ protected AbstractContainerOptions(Builder<?, ?> builder) {
this.acknowledgementThreshold = builder.acknowledgementThreshold;
this.componentsTaskExecutor = builder.componentsTaskExecutor;
this.acknowledgementResultTaskExecutor = builder.acknowledgementResultTaskExecutor;
this.messageFilter = builder.messageFilter;

Assert.isTrue(this.maxMessagesPerPoll <= this.maxConcurrentMessages, String.format(
"messagesPerPoll should be less than or equal to maxConcurrentMessages. Values provided: %s and %s respectively",
this.maxMessagesPerPoll, this.maxConcurrentMessages));
Expand Down Expand Up @@ -164,6 +171,11 @@ public MessagingMessageConverter<?> getMessageConverter() {
return this.messageConverter;
}

@Override
public MessageFilter<?> getMessageFilter() {
return this.messageFilter;
}

@Nullable
@Override
public Duration getAcknowledgementInterval() {
Expand Down Expand Up @@ -244,6 +256,8 @@ protected abstract static class Builder<B extends ContainerOptionsBuilder<B, O>,

private AcknowledgementMode acknowledgementMode = DEFAULT_ACKNOWLEDGEMENT_MODE;

private MessageFilter<?> messageFilter = new DefaultMessageFilter<>();

@Nullable
private AcknowledgementOrdering acknowledgementOrdering;

Expand Down Expand Up @@ -280,6 +294,7 @@ protected Builder(AbstractContainerOptions<?, ?> options) {
this.acknowledgementThreshold = options.acknowledgementThreshold;
this.componentsTaskExecutor = options.componentsTaskExecutor;
this.acknowledgementResultTaskExecutor = options.acknowledgementResultTaskExecutor;
this.messageFilter = options.messageFilter;
}

@Override
Expand Down Expand Up @@ -400,6 +415,13 @@ public B messageConverter(MessagingMessageConverter<?> messageConverter) {
return self();
}

@Override
public B messageFilter(MessageFilter<?> messageFilter) {
Assert.notNull(messageFilter, "messageFilter cannot be null");
this.messageFilter = messageFilter;
return self();
}

@SuppressWarnings("unchecked")
private B self() {
return (B) this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@
import io.awspring.cloud.sqs.listener.errorhandler.ErrorHandler;
import io.awspring.cloud.sqs.listener.interceptor.AsyncMessageInterceptor;
import io.awspring.cloud.sqs.listener.interceptor.MessageInterceptor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.function.Consumer;
import io.awspring.cloud.sqs.support.filter.DefaultMessageFilter;
import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.function.Consumer;

/**
* Base implementation for {@link MessageListenerContainer} with {@link SmartLifecycle} and component management
* capabilities.
Expand Down Expand Up @@ -132,7 +135,11 @@ public void addMessageInterceptor(AsyncMessageInterceptor<T> messageInterceptor)
@Override
public void setMessageListener(MessageListener<T> messageListener) {
Assert.notNull(messageListener, "messageListener cannot be null");
this.messageListener = AsyncComponentAdapters.adapt(messageListener);
if (containerOptions.getMessageFilter() instanceof DefaultMessageFilter) {
this.messageListener = AsyncComponentAdapters.adapt(messageListener);
} else {
this.messageListener = AsyncComponentAdapters.adaptFilter(messageListener, (MessageFilter<T>) containerOptions.getMessageFilter());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.task.TaskExecutor;
Expand Down Expand Up @@ -76,6 +78,19 @@ public static <T> AsyncMessageListener<T> adapt(MessageListener<T> messageListen
return new BlockingMessageListenerAdapter<>(messageListener);
}

/**
* Adapt the provided {@link MessageListener} and {@link MessageFilter} into a single {@link AsyncMessageListener}
* that only forwards messages passing the filter.
*
* @param messageListener the message listener to be adapted
* @param messageFilter the filter used to evaluate incoming messages
* @param <T> the message payload type
* @return the adapted and filtered async message listener
*/
public static <T> AsyncMessageListener<T> adaptFilter(MessageListener<T> messageListener, MessageFilter<T> messageFilter) {
return new FilteredMessageListenerAdapter<>(messageListener, messageFilter);
}

public static <T> AsyncAcknowledgementResultCallback<T> adapt(
AcknowledgementResultCallback<T> acknowledgementResultCallback) {
return new BlockingAcknowledgementResultCallbackAdapter<>(acknowledgementResultCallback);
Expand Down Expand Up @@ -214,6 +229,42 @@ public CompletableFuture<Void> onMessage(Collection<Message<T>> messages) {
}
}

private static class FilteredMessageListenerAdapter<T> extends AbstractThreadingComponentAdapter
implements AsyncMessageListener<T> {

private final MessageListener<T> filteredMessageListener;
private final MessageFilter<T> filter;

public FilteredMessageListenerAdapter(MessageListener<T> filteredMessageListener, MessageFilter<T> filter) {
this.filteredMessageListener = filteredMessageListener;
this.filter = filter;
}

@Override
public CompletableFuture<Void> onMessage(Message<T> message) {
if (filter.process(message)) {
return execute(() -> this.filteredMessageListener.onMessage(message));
}
else {
return CompletableFuture.completedFuture(null);
}
}

@Override
public CompletableFuture<Void> onMessage(Collection<Message<T>> messages) {
Collection<Message<T>> filteredMessages = messages.stream()
.filter(filter::process)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat pick again but please switch to for. Since size will be always maximum 10 lets switch to for loop to take benefit of performance.

.toList();

if (filteredMessages.isEmpty()) {
return CompletableFuture.completedFuture(null);
}

return execute(() -> this.filteredMessageListener.onMessage(filteredMessages));
}
}


private static class BlockingErrorHandlerAdapter<T> extends AbstractThreadingComponentAdapter
implements AsyncErrorHandler<T> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import java.time.Duration;
import java.util.Collection;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.retry.backoff.BackOffPolicy;
Expand Down Expand Up @@ -139,6 +141,11 @@ default BackOffPolicy getPollBackOffPolicy() {
*/
MessagingMessageConverter<?> getMessageConverter();

/** Return the message filter applied before message processing.
* @return the message filter.
*/
MessageFilter<?> getMessageFilter();

/**
* Return the maximum interval between acknowledgements for batch acknowledgements.
* @return the interval.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.awspring.cloud.sqs.listener.acknowledgement.handler.AcknowledgementMode;
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import java.time.Duration;

import io.awspring.cloud.sqs.support.filter.MessageFilter;
import org.springframework.core.task.TaskExecutor;
import org.springframework.retry.backoff.BackOffPolicy;

Expand Down Expand Up @@ -187,6 +189,14 @@ default B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) {
*/
B messageConverter(MessagingMessageConverter<?> messageConverter);

/**
* Set the {@link MessagingMessageConverter} for this container.
*
* @param messageFilter the message filter.
* @return this instance.
*/
B messageFilter(MessageFilter<?> messageFilter);

/**
* Create the {@link ContainerOptions} instance.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.awspring.cloud.sqs.support.filter;

import org.springframework.messaging.Message;

public class DefaultMessageFilter<T> implements MessageFilter<T> {
@Override
public boolean process(Message<T> message) {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.awspring.cloud.sqs.support.filter;

import org.springframework.messaging.Message;

public interface MessageFilter<T> {
boolean process(Message<T> message);
}
Loading