Skip to content

Commit bab03c1

Browse files
jinwoobae
authored and committed
fix: Handle TopicPartition serialization in KafkaItemReader
- Store offsets in the ExecutionContext as a list of String-keyed maps (topic, partition, offset) instead of a TopicPartition-keyed map
- Restore offsets from that list format when reopening the reader
- Prevents ClassCastException when the JobRepository serializes the ExecutionContext with Jackson

Resolves spring-projects#3797
1 parent 08c4cb1 commit bab03c1

File tree

2 files changed

+207
-12
lines changed

2 files changed

+207
-12
lines changed

spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemReader.java

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.util.Map;
2626
import java.util.Properties;
2727

28+
import org.apache.commons.logging.Log;
29+
import org.apache.commons.logging.LogFactory;
2830
import org.apache.kafka.clients.consumer.ConsumerConfig;
2931
import org.apache.kafka.clients.consumer.ConsumerRecord;
3032
import org.apache.kafka.clients.consumer.KafkaConsumer;
@@ -52,7 +54,15 @@
5254
*/
5355
public class KafkaItemReader<K, V> extends AbstractItemStreamItemReader<V> {
5456

57+
private static final Log log = LogFactory.getLog(KafkaItemReader.class);
58+
5559
private static final String TOPIC_PARTITION_OFFSETS = "topic.partition.offsets";
60+
61+
private static final String KEY_TOPIC = "topic";
62+
63+
private static final String KEY_PARTITION = "partition";
64+
65+
private static final String KEY_OFFSET = "offset";
5666

5767
private static final long DEFAULT_POLL_TIMEOUT = 30L;
5868

@@ -167,21 +177,50 @@ public void setPartitionOffsets(Map<TopicPartition, Long> partitionOffsets) {
167177
@Override
168178
public void open(ExecutionContext executionContext) {
169179
this.kafkaConsumer = new KafkaConsumer<>(this.consumerProperties);
180+
initializePartitionOffsets();
181+
182+
if (this.saveState && executionContext.containsKey(getExecutionContextKey(TOPIC_PARTITION_OFFSETS))) {
183+
List<Map<String, Object>> storedOffsets = (List<Map<String, Object>>) executionContext.get(
184+
getExecutionContextKey(TOPIC_PARTITION_OFFSETS));
185+
restorePartitionOffsets(storedOffsets);
186+
}
187+
188+
this.kafkaConsumer.assign(this.topicPartitions);
189+
this.partitionOffsets.forEach(this.kafkaConsumer::seek);
190+
}
191+
192+
/**
193+
* Initialize partition offsets with default values if not already set.
194+
*/
195+
private void initializePartitionOffsets() {
170196
if (this.partitionOffsets == null) {
171197
this.partitionOffsets = new HashMap<>();
172198
for (TopicPartition topicPartition : this.topicPartitions) {
173199
this.partitionOffsets.put(topicPartition, 0L);
174200
}
175201
}
176-
if (this.saveState && executionContext.containsKey(TOPIC_PARTITION_OFFSETS)) {
177-
Map<TopicPartition, Long> offsets = (Map<TopicPartition, Long>) executionContext
178-
.get(TOPIC_PARTITION_OFFSETS);
179-
for (Map.Entry<TopicPartition, Long> entry : offsets.entrySet()) {
180-
this.partitionOffsets.put(entry.getKey(), entry.getValue() == 0 ? 0 : entry.getValue() + 1);
202+
}
203+
204+
/**
205+
* Restore partition offsets from the stored list.
206+
* Each entry in the list contains topic, partition, and offset information.
207+
* @param storedOffsets the offsets stored in execution context
208+
*/
209+
private void restorePartitionOffsets(List<Map<String, Object>> storedOffsets) {
210+
for (Map<String, Object> offsetEntry : storedOffsets) {
211+
String topic = (String) offsetEntry.get(KEY_TOPIC);
212+
Number partition = (Number) offsetEntry.get(KEY_PARTITION);
213+
Number offset = (Number) offsetEntry.get(KEY_OFFSET);
214+
215+
if (topic == null || partition == null || offset == null) {
216+
log.warn("Incomplete offset entry found in execution context, skipping: " + offsetEntry);
217+
continue;
181218
}
219+
220+
TopicPartition topicPartition = new TopicPartition(topic, partition.intValue());
221+
long offsetValue = offset.longValue();
222+
this.partitionOffsets.put(topicPartition, offsetValue == 0 ? 0 : offsetValue + 1);
182223
}
183-
this.kafkaConsumer.assign(this.topicPartitions);
184-
this.partitionOffsets.forEach(this.kafkaConsumer::seek);
185224
}
186225

187226
@Nullable
@@ -202,8 +241,18 @@ public V read() {
202241

203242
@Override
204243
public void update(ExecutionContext executionContext) {
205-
if (this.saveState) {
206-
executionContext.put(TOPIC_PARTITION_OFFSETS, new HashMap<>(this.partitionOffsets));
244+
if (this.saveState && this.partitionOffsets != null) {
245+
List<Map<String, Object>> offsetsToStore = new ArrayList<>();
246+
247+
this.partitionOffsets.forEach((topicPartition, offset) -> {
248+
Map<String, Object> offsetEntry = new HashMap<>();
249+
offsetEntry.put(KEY_TOPIC, topicPartition.topic());
250+
offsetEntry.put(KEY_PARTITION, topicPartition.partition());
251+
offsetEntry.put(KEY_OFFSET, offset);
252+
offsetsToStore.add(offsetEntry);
253+
});
254+
255+
executionContext.put(getExecutionContextKey(TOPIC_PARTITION_OFFSETS), offsetsToStore);
207256
}
208257
this.kafkaConsumer.commitSync();
209258
}

spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemReaderTests.java

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,19 @@
1717
package org.springframework.batch.item.kafka;
1818

1919
import java.time.Duration;
20-
import java.util.Properties;
20+
import java.util.*;
2121

22+
import com.fasterxml.jackson.databind.ObjectMapper;
23+
import org.apache.kafka.clients.consumer.ConsumerConfig;
24+
import org.apache.kafka.common.TopicPartition;
2225
import org.apache.kafka.common.serialization.StringDeserializer;
2326
import org.junit.jupiter.api.Test;
27+
import org.mockito.MockedConstruction;
2428

25-
import static org.junit.jupiter.api.Assertions.assertEquals;
26-
import static org.junit.jupiter.api.Assertions.assertThrows;
29+
import org.springframework.batch.item.ExecutionContext;
30+
31+
import static org.junit.jupiter.api.Assertions.*;
32+
import static org.mockito.Mockito.mockConstruction;
2733

2834
/**
2935
* @author Mathieu Ouellet
@@ -77,4 +83,144 @@ void testValidation() {
7783
assertEquals("pollTimeout must not be negative", exception.getMessage());
7884
}
7985

86+
@Test
87+
void testExecutionContextSerializationWithJackson() throws Exception {
88+
Properties consumerProperties = new Properties();
89+
consumerProperties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "mockServer");
90+
consumerProperties.put(ConsumerConfig.GROUP_ID_CONFIG, "testGroup");
91+
consumerProperties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
92+
consumerProperties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
93+
94+
KafkaItemReader<String, String> reader = new KafkaItemReader<>(consumerProperties, "testTopic", 0, 1);
95+
reader.setName("kafkaItemReader");
96+
97+
// Simulate how Jackson would serialize/deserialize the offset data
98+
ExecutionContext executionContext = new ExecutionContext();
99+
List<Map<String, Object>> offsets = new ArrayList<>();
100+
101+
Map<String, Object> offset1 = new HashMap<>();
102+
offset1.put("topic", "testTopic");
103+
offset1.put("partition", 0);
104+
offset1.put("offset", 100L);
105+
offsets.add(offset1);
106+
107+
Map<String, Object> offset2 = new HashMap<>();
108+
offset2.put("topic", "testTopic");
109+
offset2.put("partition", 1);
110+
offset2.put("offset", 200L);
111+
offsets.add(offset2);
112+
113+
// Simulate Jackson serialization/deserialization
114+
ObjectMapper objectMapper = new ObjectMapper();
115+
String serialized = objectMapper.writeValueAsString(offsets);
116+
List<Map<String, Object>> deserializedOffsets = objectMapper.readValue(serialized, List.class);
117+
118+
executionContext.put("kafkaItemReader.topic.partition.offsets", deserializedOffsets);
119+
120+
try (MockedConstruction<org.apache.kafka.clients.consumer.KafkaConsumer> mockedConstruction = mockConstruction(
121+
org.apache.kafka.clients.consumer.KafkaConsumer.class)) {
122+
123+
reader.open(executionContext);
124+
125+
ExecutionContext newContext = new ExecutionContext();
126+
reader.update(newContext);
127+
128+
List<Map<String, Object>> savedOffsets = (List<Map<String, Object>>) newContext.get("kafkaItemReader.topic.partition.offsets");
129+
assertNotNull(savedOffsets);
130+
assertEquals(2, savedOffsets.size());
131+
132+
boolean foundPartition0 = false;
133+
boolean foundPartition1 = false;
134+
for (Map<String, Object> offsetEntry : savedOffsets) {
135+
String topic = (String) offsetEntry.get("topic");
136+
Integer partition = (Integer) offsetEntry.get("partition");
137+
Long offset = (Long) offsetEntry.get("offset");
138+
139+
assertEquals("testTopic", topic);
140+
assertNotNull(offset);
141+
142+
if (partition == 0) {
143+
foundPartition0 = true;
144+
assertEquals(101L, offset); // restored offset + 1
145+
} else if (partition == 1) {
146+
foundPartition1 = true;
147+
assertEquals(201L, offset); // restored offset + 1
148+
}
149+
}
150+
151+
assertTrue(foundPartition0);
152+
assertTrue(foundPartition1);
153+
}
154+
}
155+
156+
@Test
157+
void testExecutionContextWithStringKeys() throws Exception {
158+
Properties consumerProperties = new Properties();
159+
consumerProperties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "mockServer");
160+
consumerProperties.put(ConsumerConfig.GROUP_ID_CONFIG, "testGroup");
161+
consumerProperties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
162+
consumerProperties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
163+
164+
KafkaItemReader<String, String> reader = new KafkaItemReader<>(consumerProperties, "testTopic", 0, 1);
165+
reader.setName("kafkaItemReader");
166+
167+
// Create ExecutionContext with list of maps (as it would be after Jackson
168+
// deserialization)
169+
ExecutionContext executionContext = new ExecutionContext();
170+
List<Map<String, Object>> storedOffsets = new ArrayList<>();
171+
172+
Map<String, Object> offset1 = new HashMap<>();
173+
offset1.put("topic", "testTopic");
174+
offset1.put("partition", 0);
175+
offset1.put("offset", 100L);
176+
storedOffsets.add(offset1);
177+
178+
Map<String, Object> offset2 = new HashMap<>();
179+
offset2.put("topic", "testTopic");
180+
offset2.put("partition", 1);
181+
offset2.put("offset", 200L);
182+
storedOffsets.add(offset2);
183+
184+
executionContext.put("kafkaItemReader.topic.partition.offsets", storedOffsets);
185+
186+
try (MockedConstruction<org.apache.kafka.clients.consumer.KafkaConsumer> mockedConstruction = mockConstruction(
187+
org.apache.kafka.clients.consumer.KafkaConsumer.class)) {
188+
189+
reader.open(executionContext);
190+
191+
// Verify that offsets are saved correctly
192+
ExecutionContext newContext = new ExecutionContext();
193+
reader.update(newContext);
194+
195+
List<Map<String, Object>> savedOffsets = (List<Map<String, Object>>) newContext.get("kafkaItemReader.topic.partition.offsets");
196+
assertNotNull(savedOffsets);
197+
assertEquals(2, savedOffsets.size());
198+
}
199+
}
200+
201+
@Test
202+
void testJacksonSerializationFormat() throws Exception {
203+
// Test to verify the actual format when Jackson serializes our offset structure
204+
List<Map<String, Object>> offsets = new ArrayList<>();
205+
206+
Map<String, Object> offset1 = new HashMap<>();
207+
offset1.put("topic", "test-topic");
208+
offset1.put("partition", 0);
209+
offset1.put("offset", 100L);
210+
offsets.add(offset1);
211+
212+
ObjectMapper objectMapper = new ObjectMapper();
213+
String serialized = objectMapper.writeValueAsString(offsets);
214+
215+
// Verify the structure can be deserialized correctly
216+
List<Map<String, Object>> deserialized = objectMapper.readValue(serialized, List.class);
217+
assertEquals(1, deserialized.size());
218+
219+
Map<String, Object> deserializedOffset = deserialized.get(0);
220+
assertEquals("test-topic", deserializedOffset.get("topic"));
221+
// Jackson may deserialize numbers as Integer or Long depending on the value
222+
assertEquals(0, ((Number) deserializedOffset.get("partition")).intValue());
223+
assertEquals(100L, ((Number) deserializedOffset.get("offset")).longValue());
224+
}
225+
80226
}

0 commit comments

Comments
 (0)