Skip to content

Commit 5540d5f

Browse files
authored
fix: handle composite map types (#182)
1 parent 997191c commit 5540d5f

File tree

4 files changed

+73 -134 lines changed

dagger-common/src/main/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandler.java

Lines changed: 29 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,11 @@
22

33
import com.google.protobuf.Descriptors;
44
import com.google.protobuf.DynamicMessage;
5-
import com.google.protobuf.MapEntry;
6-
import com.google.protobuf.WireFormat;
7-
import io.odpf.dagger.common.serde.typehandler.TypeHandler;
85
import io.odpf.dagger.common.serde.typehandler.RowFactory;
6+
import io.odpf.dagger.common.serde.typehandler.TypeHandler;
7+
import io.odpf.dagger.common.serde.typehandler.TypeHandlerFactory;
98
import io.odpf.dagger.common.serde.typehandler.TypeInformationFactory;
9+
import io.odpf.dagger.common.serde.typehandler.repeated.RepeatedMessageHandler;
1010
import org.apache.flink.api.common.typeinfo.TypeInformation;
1111
import org.apache.flink.api.common.typeinfo.Types;
1212
import org.apache.flink.types.Row;
@@ -27,6 +27,7 @@
2727
public class MapHandler implements TypeHandler {
2828

2929
private Descriptors.FieldDescriptor fieldDescriptor;
30+
private TypeHandler repeatedMessageHandler;
3031

3132
/**
3233
* Instantiates a new Map proto handler.
@@ -35,6 +36,7 @@ public class MapHandler implements TypeHandler {
3536
*/
3637
public MapHandler(Descriptors.FieldDescriptor fieldDescriptor) {
3738
this.fieldDescriptor = fieldDescriptor;
39+
this.repeatedMessageHandler = new RepeatedMessageHandler(fieldDescriptor);
3840
}
3941

4042
@Override
@@ -47,38 +49,44 @@ public DynamicMessage.Builder transformToProtoBuilder(DynamicMessage.Builder bui
4749
if (!canHandle() || field == null) {
4850
return builder;
4951
}
50-
5152
if (field instanceof Map) {
52-
convertFromMap(builder, (Map<String, String>) field);
53-
}
54-
55-
if (field instanceof Object[]) {
56-
convertFromRow(builder, (Object[]) field);
53+
Map<?, ?> mapField = (Map<?, ?>) field;
54+
ArrayList<Row> rows = new ArrayList<>();
55+
for (Entry<?, ?> entry : mapField.entrySet()) {
56+
rows.add(Row.of(entry.getKey(), entry.getValue()));
57+
}
58+
return repeatedMessageHandler.transformToProtoBuilder(builder, rows.toArray());
5759
}
58-
59-
return builder;
60+
return repeatedMessageHandler.transformToProtoBuilder(builder, field);
6061
}
6162

6263
@Override
6364
public Object transformFromPostProcessor(Object field) {
6465
ArrayList<Row> rows = new ArrayList<>();
65-
if (field != null) {
66-
Map<String, String> mapField = (Map<String, String>) field;
67-
for (Entry<String, String> entry : mapField.entrySet()) {
68-
rows.add(getRowFromMap(entry));
66+
if (field == null) {
67+
return rows.toArray();
68+
}
69+
if (field instanceof Map) {
70+
Map<String, ?> mapField = (Map<String, ?>) field;
71+
for (Entry<String, ?> entry : mapField.entrySet()) {
72+
Descriptors.FieldDescriptor keyDescriptor = fieldDescriptor.getMessageType().findFieldByName("key");
73+
Descriptors.FieldDescriptor valueDescriptor = fieldDescriptor.getMessageType().findFieldByName("value");
74+
TypeHandler handler = TypeHandlerFactory.getTypeHandler(keyDescriptor);
75+
Object key = handler.transformFromPostProcessor(entry.getKey());
76+
Object value = TypeHandlerFactory.getTypeHandler(valueDescriptor).transformFromPostProcessor(entry.getValue());
77+
rows.add(Row.of(key, value));
6978
}
79+
return rows.toArray();
80+
}
81+
if (field instanceof List) {
82+
return repeatedMessageHandler.transformFromPostProcessor(field);
7083
}
7184
return rows.toArray();
7285
}
7386

7487
@Override
7588
public Object transformFromProto(Object field) {
76-
ArrayList<Row> rows = new ArrayList<>();
77-
if (field != null) {
78-
List<DynamicMessage> protos = (List<DynamicMessage>) field;
79-
protos.forEach(proto -> rows.add(getRowFromMap(proto)));
80-
}
81-
return rows.toArray();
89+
return repeatedMessageHandler.transformFromProto(field);
8290
}
8391

8492
@Override
@@ -127,53 +135,4 @@ public Object transformToJson(Object field) {
127135
public TypeInformation getTypeInformation() {
128136
return Types.OBJECT_ARRAY(TypeInformationFactory.getRowType(fieldDescriptor.getMessageType()));
129137
}
130-
131-
private Row getRowFromMap(Entry<String, String> entry) {
132-
Row row = new Row(2);
133-
row.setField(0, entry.getKey());
134-
row.setField(1, entry.getValue());
135-
return row;
136-
}
137-
138-
private Row getRowFromMap(DynamicMessage proto) {
139-
Row row = new Row(2);
140-
row.setField(0, parse(proto, "key"));
141-
row.setField(1, parse(proto, "value"));
142-
return row;
143-
}
144-
145-
private Object parse(DynamicMessage proto, String fieldName) {
146-
Object field = proto.getField(proto.getDescriptorForType().findFieldByName(fieldName));
147-
if (DynamicMessage.class.equals(field.getClass())) {
148-
field = RowFactory.createRow((DynamicMessage) field);
149-
}
150-
return field;
151-
}
152-
153-
private void convertFromRow(DynamicMessage.Builder builder, Object[] field) {
154-
for (Object inputValue : field) {
155-
Row inputRow = (Row) inputValue;
156-
if (inputRow.getArity() != 2) {
157-
throw new IllegalArgumentException("Row: " + inputRow.toString() + " of size: " + inputRow.getArity() + " cannot be converted to map");
158-
}
159-
MapEntry<String, String> mapEntry = MapEntry
160-
.newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, "");
161-
builder.addRepeatedField(fieldDescriptor,
162-
mapEntry.toBuilder()
163-
.setKey((String) inputRow.getField(0))
164-
.setValue((String) inputRow.getField(1))
165-
.buildPartial());
166-
}
167-
}
168-
169-
private void convertFromMap(DynamicMessage.Builder builder, Map<String, String> field) {
170-
for (Entry<String, String> entry : field.entrySet()) {
171-
MapEntry<String, String> mapEntry = MapEntry.newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, "");
172-
builder.addRepeatedField(fieldDescriptor,
173-
mapEntry.toBuilder()
174-
.setKey(entry.getKey())
175-
.setValue(entry.getValue())
176-
.buildPartial());
177-
}
178-
}
179138
}

dagger-common/src/test/java/io/odpf/dagger/common/serde/typehandler/complex/MapHandlerTest.java

Lines changed: 38 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
import org.apache.parquet.schema.LogicalTypeAnnotation;
1717
import org.apache.parquet.schema.MessageType;
1818
import org.apache.parquet.schema.PrimitiveType;
19-
import org.junit.Assert;
2019
import org.junit.Test;
2120

2221
import java.util.ArrayList;
@@ -81,13 +80,11 @@ public void shouldSetMapFieldIfStringMapPassed() {
8180
inputMap.put("b", "456");
8281

8382
DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputMap);
84-
List<MapEntry> entries = (List<MapEntry>) returnedBuilder.getField(mapFieldDescriptor);
83+
List<DynamicMessage> entries = (List<DynamicMessage>) returnedBuilder.getField(mapFieldDescriptor);
8584

8685
assertEquals(2, entries.size());
87-
assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]);
88-
assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]);
89-
assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]);
90-
assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]);
86+
assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray());
87+
assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray());
9188
}
9289

9390
@Test
@@ -111,28 +108,29 @@ public void shouldSetMapFieldIfArrayofObjectsHavingRowsWithStringFieldsPassed()
111108
inputRows.add(inputRow2);
112109

113110
DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputRows.toArray());
114-
List<MapEntry> entries = (List<MapEntry>) returnedBuilder.getField(mapFieldDescriptor);
111+
List<DynamicMessage> entries = (List<DynamicMessage>) returnedBuilder.getField(mapFieldDescriptor);
115112

116113
assertEquals(2, entries.size());
117-
assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]);
118-
assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]);
119-
assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]);
120-
assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]);
114+
assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray());
115+
assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray());
121116
}
122117

123118
@Test
124-
public void shouldThrowExceptionIfRowsPassedAreNotOfArityTwo() {
125-
Descriptors.FieldDescriptor mapFieldDescriptor = TestBookingLogMessage.getDescriptor().findFieldByName("metadata");
126-
MapHandler mapHandler = new MapHandler(mapFieldDescriptor);
127-
DynamicMessage.Builder builder = DynamicMessage.newBuilder(mapFieldDescriptor.getContainingType());
119+
public void shouldHandleComplexTypeValuesForSerialization() throws InvalidProtocolBufferException {
120+
Row inputValue1 = Row.of("12345", Row.of(Arrays.asList("a", "b")));
121+
Row inputValue2 = Row.of(1234123, Row.of(Arrays.asList("d", "e")));
122+
Object input = Arrays.asList(inputValue1, inputValue2).toArray();
128123

129-
ArrayList<Row> inputRows = new ArrayList<>();
124+
Descriptors.FieldDescriptor intMessageDescriptor = TestComplexMap.getDescriptor().findFieldByName("int_message");
125+
DynamicMessage.Builder builder = DynamicMessage.newBuilder(TestComplexMap.getDescriptor());
130126

131-
Row inputRow = new Row(3);
132-
inputRows.add(inputRow);
133-
IllegalArgumentException exception = Assert.assertThrows(IllegalArgumentException.class,
134-
() -> mapHandler.transformToProtoBuilder(builder, inputRows.toArray()));
135-
assertEquals("Row: +I[null, null, null] of size: 3 cannot be converted to map", exception.getMessage());
127+
byte[] data = new MapHandler(intMessageDescriptor).transformToProtoBuilder(builder, input).build().toByteArray();
128+
TestComplexMap actualMsg = TestComplexMap.parseFrom(data);
129+
assertArrayEquals(Arrays.asList(12345L, 1234123L).toArray(), actualMsg.getIntMessageMap().keySet().toArray());
130+
TestComplexMap.IdMessage idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[0];
131+
assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("a", "b")));
132+
idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[1];
133+
assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("d", "e")));
136134
}
137135

138136
@Test
@@ -158,12 +156,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf
158156

159157
List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromPostProcessor(inputMap));
160158

161-
assertEquals("a", ((Row) outputValues.get(0)).getField(0));
162-
assertEquals("123", ((Row) outputValues.get(0)).getField(1));
163-
assertEquals(2, ((Row) outputValues.get(0)).getArity());
164-
assertEquals("b", ((Row) outputValues.get(1)).getField(0));
165-
assertEquals("456", ((Row) outputValues.get(1)).getField(1));
166-
assertEquals(2, ((Row) outputValues.get(1)).getArity());
159+
assertEquals(Row.of("a", "123"), outputValues.get(0));
160+
assertEquals(Row.of("b", "456"), outputValues.get(1));
167161
}
168162

169163
@Test
@@ -210,12 +204,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf
210204

211205
List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));
212206

213-
assertEquals("a", ((Row) outputValues.get(0)).getField(0));
214-
assertEquals("123", ((Row) outputValues.get(0)).getField(1));
215-
assertEquals(2, ((Row) outputValues.get(0)).getArity());
216-
assertEquals("b", ((Row) outputValues.get(1)).getField(0));
217-
assertEquals("456", ((Row) outputValues.get(1)).getField(1));
218-
assertEquals(2, ((Row) outputValues.get(1)).getArity());
207+
assertEquals(Row.of("a", "123"), outputValues.get(0));
208+
assertEquals(Row.of("b", "456"), outputValues.get(1));
219209
}
220210

221211
@Test
@@ -247,16 +237,11 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie
247237

248238
List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));
249239

250-
assertEquals(1, ((Row) outputValues.get(0)).getField(0));
251-
assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
252-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
253-
assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
254-
assertEquals(2, ((Row) outputValues.get(0)).getArity());
255-
assertEquals(2, ((Row) outputValues.get(1)).getField(0));
256-
assertEquals("456", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(0));
257-
assertEquals("", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(1));
258-
assertEquals("efg", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(2));
259-
assertEquals(2, ((Row) outputValues.get(1)).getArity());
240+
Row mapEntry1 = Row.of(1, Row.of("123", "", "abc"));
241+
Row mapEntry2 = Row.of(2, Row.of("456", "", "efg"));
242+
243+
assertEquals(mapEntry1, outputValues.get(0));
244+
assertEquals(mapEntry2, outputValues.get(1));
260245
}
261246

262247
@Test
@@ -271,11 +256,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie
271256

272257
List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));
273258

274-
assertEquals(0, ((Row) outputValues.get(0)).getField(0));
275-
assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
276-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
277-
assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
278-
assertEquals(2, ((Row) outputValues.get(0)).getArity());
259+
Row expected = Row.of(0, Row.of("123", "", "abc"));
260+
assertEquals(expected, outputValues.get(0));
279261
}
280262

281263
@Test
@@ -290,11 +272,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie
290272

291273
List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));
292274

293-
assertEquals(1, ((Row) outputValues.get(0)).getField(0));
294-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
295-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
296-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
297-
assertEquals(2, ((Row) outputValues.get(0)).getArity());
275+
Row expected = Row.of(1, Row.of("", "", ""));
276+
277+
assertEquals(expected, outputValues.get(0));
298278
}
299279

300280
@Test
@@ -309,11 +289,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie
309289

310290
List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));
311291

312-
assertEquals(0, ((Row) outputValues.get(0)).getField(0));
313-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
314-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
315-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
316-
assertEquals(2, ((Row) outputValues.get(0)).getArity());
292+
Row expected = Row.of(0, Row.of("", "", ""));
293+
294+
assertEquals(expected, outputValues.get(0));
317295
}
318296

319297
@Test
@@ -328,11 +306,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie
328306

329307
List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));
330308

331-
assertEquals(0, ((Row) outputValues.get(0)).getField(0));
332-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
333-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
334-
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
335-
assertEquals(2, ((Row) outputValues.get(0)).getArity());
309+
Row expected = Row.of(0, Row.of("", "", ""));
310+
assertEquals(expected, outputValues.get(0));
336311
}
337312

338313
@Test

dagger-common/src/test/proto/TestMessage.proto

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,12 @@ message TestEnumMessage {
5353
}
5454

5555
message TestComplexMap {
56+
message IdMessage {
57+
repeated string ids = 1;
58+
}
5659
map<int32, TestMessage> complex_map = 1;
60+
map<int64, IdMessage> int_message = 2;
61+
map<string, IdMessage> string_message = 3;
5762
}
5863

5964
message TestRepeatedPrimitiveMessage {

dagger-core/src/test/java/io/odpf/dagger/core/processors/internal/processor/function/functions/JsonPayloadFunctionTest.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,7 @@ public void shouldGetCorrectJsonPayloadForComplexFields() throws InvalidProtocol
269269
DynamicMessage dynamicMessage = DynamicMessage.parseFrom(complexMapMessage.getDescriptor(), complexMapMessage.toByteArray());
270270
RowManager rowManager = getRowManagerForMessage(dynamicMessage);
271271

272-
String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}]}";
272+
String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}],\"int_message\":[],\"string_message\":[]}";
273273
String actualJsonPayload = (String) jsonPayloadFunction.getResult(rowManager);
274274

275275
assertEquals(expectedJsonPayload, actualJsonPayload);

0 commit comments

Comments
 (0)