diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java index 92585ff9a4..fc0fa9e7e8 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java @@ -936,6 +936,16 @@ else if (listener instanceof MessageListener) { this.isConsumerAwareListener = listenerType.equals(ListenerType.ACKNOWLEDGING_CONSUMER_AWARE) || listenerType.equals(ListenerType.CONSUMER_AWARE); this.commonErrorHandler = determineCommonErrorHandler(); + // Setup async failure callback for suspend functions when CommonErrorHandler is explicitly configured + if (getCommonErrorHandler() != null && this.listener != null) { + MessageListener target = unwrapDelegateIfAny(this.listener); + if (target instanceof RecordMessagingMessageListenerAdapter) { + @SuppressWarnings("unchecked") + RecordMessagingMessageListenerAdapter adapter = + (RecordMessagingMessageListenerAdapter) target; + adapter.setCallbackForAsyncFailure(this::callbackForAsyncFailure); + } + } Assert.state(!this.isBatchListener || !this.isRecordAck, "Cannot use AckMode.RECORD with a batch listener"); if (this.containerProperties.getScheduler() != null) { diff --git a/spring-kafka/src/test/kotlin/org/springframework/kafka/listener/EnableKafkaKotlinCoroutinesTests.kt b/spring-kafka/src/test/kotlin/org/springframework/kafka/listener/EnableKafkaKotlinCoroutinesTests.kt index 3033da54e2..8735e595f3 100644 --- a/spring-kafka/src/test/kotlin/org/springframework/kafka/listener/EnableKafkaKotlinCoroutinesTests.kt +++ b/spring-kafka/src/test/kotlin/org/springframework/kafka/listener/EnableKafkaKotlinCoroutinesTests.kt @@ -35,6 +35,7 @@ import org.springframework.kafka.core.DefaultKafkaConsumerFactory import org.springframework.kafka.core.DefaultKafkaProducerFactory import org.springframework.kafka.core.KafkaTemplate import org.springframework.kafka.core.ProducerFactory +import org.springframework.kafka.listener.DefaultErrorHandler import org.springframework.kafka.listener.KafkaListenerErrorHandler import org.springframework.kafka.support.Acknowledgment import org.springframework.kafka.test.EmbeddedKafkaBroker @@ -59,7 +60,8 @@ import java.util.concurrent.TimeUnit @SpringJUnitConfig @DirtiesContext @EmbeddedKafka(topics = ["kotlinAsyncTestTopic1", "kotlinAsyncTestTopic2", - "kotlinAsyncBatchTestTopic1", "kotlinAsyncBatchTestTopic2", "kotlinReplyTopic1"], partitions = 1) + "kotlinAsyncBatchTestTopic1", "kotlinAsyncBatchTestTopic2", "kotlinReplyTopic1", + "kotlinAsyncTestTopicCommonHandler"], partitions = 1) class EnableKafkaKotlinCoroutinesTests { @Autowired @@ -108,6 +110,13 @@ class EnableKafkaKotlinCoroutinesTests { assertThat(cr?.value() ?: "null").isEqualTo("FOO") } + @Test + fun `test suspend function with CommonErrorHandler`() { + this.template.send("kotlinAsyncTestTopicCommonHandler", "fail") + assertThat(this.config.commonHandlerLatch.await(10, TimeUnit.SECONDS)).isTrue() + assertThat(this.config.commonHandlerInvoked).isTrue() + } + @KafkaListener(id = "sendTopic", topics = ["kotlinAsyncTestTopic3"], containerFactory = "kafkaListenerContainerFactory") class Listener { @@ -138,6 +147,9 @@ class EnableKafkaKotlinCoroutinesTests { @Volatile var batchError: Boolean = false + @Volatile + var commonHandlerInvoked: Boolean = false + val latch1 = CountDownLatch(1) val latch2 = CountDownLatch(1) @@ -146,6 +158,8 @@ class EnableKafkaKotlinCoroutinesTests { val batchLatch2 = CountDownLatch(1) + val commonHandlerLatch = CountDownLatch(1) + @Value("\${" + EmbeddedKafkaBroker.SPRING_EMBEDDED_KAFKA_BROKERS + "}") private lateinit var brokerAddresses: String @@ -217,6 +231,23 @@ class EnableKafkaKotlinCoroutinesTests { return factory } + @Bean + fun commonErrorHandler(): DefaultErrorHandler { + return DefaultErrorHandler { record, exception -> + commonHandlerInvoked = true + commonHandlerLatch.countDown() + } + } + + @Bean + fun kafkaListenerContainerFactoryWithCommonHandler(): ConcurrentKafkaListenerContainerFactory { + val factory: ConcurrentKafkaListenerContainerFactory + = ConcurrentKafkaListenerContainerFactory() + factory.setConsumerFactory(kcf()) + factory.setCommonErrorHandler(commonErrorHandler()) + return factory + } + @KafkaListener(id = "kotlin", topics = ["kotlinAsyncTestTopic1"], containerFactory = "kafkaListenerContainerFactory") suspend fun listen(value: String, acknowledgment: Acknowledgment) { @@ -247,6 +278,14 @@ class EnableKafkaKotlinCoroutinesTests { } } + @KafkaListener(id = "kotlin-common-handler", topics = ["kotlinAsyncTestTopicCommonHandler"], + containerFactory = "kafkaListenerContainerFactoryWithCommonHandler") + suspend fun listenWithCommonHandler(value: String) { + if (value == "fail") { + throw RuntimeException("Test exception for CommonErrorHandler") + } + } + } }