|
41 | 41 | import org.apache.lucene.store.Directory;
|
42 | 42 | import org.apache.lucene.tests.index.RandomIndexWriter;
|
43 | 43 | import org.apache.lucene.util.NumericUtils;
|
| 44 | +import org.opensearch.Version; |
| 45 | +import org.opensearch.common.io.stream.BytesStreamOutput; |
| 46 | +import org.opensearch.core.common.io.stream.StreamInput; |
44 | 47 | import org.opensearch.index.mapper.MappedFieldType;
|
45 | 48 | import org.opensearch.index.mapper.NumberFieldMapper;
|
46 | 49 | import org.opensearch.plugins.SearchPlugin;
|
| 50 | +import org.opensearch.search.MultiValueMode; |
47 | 51 | import org.opensearch.search.aggregations.AggregatorTestCase;
|
48 | 52 | import org.opensearch.search.aggregations.matrix.MatrixAggregationModulePlugin;
|
49 | 53 |
|
| 54 | +import java.io.IOException; |
50 | 55 | import java.util.Arrays;
|
51 | 56 | import java.util.Collections;
|
52 | 57 | import java.util.List;
|
@@ -126,6 +131,90 @@ public void testTwoFields() throws Exception {
|
126 | 131 | }
|
127 | 132 | }
|
128 | 133 |
|
| 134 | + public void testMultiValueModeAffectsResult() throws Exception { |
| 135 | + String field = "grades"; |
| 136 | + MappedFieldType ft = new NumberFieldMapper.NumberFieldType(field, NumberFieldMapper.NumberType.DOUBLE); |
| 137 | + |
| 138 | + try (Directory directory = newDirectory(); RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { |
| 139 | + Document doc = new Document(); |
| 140 | + doc.add(new SortedNumericDocValuesField(field, NumericUtils.doubleToSortableLong(1.0))); |
| 141 | + doc.add(new SortedNumericDocValuesField(field, NumericUtils.doubleToSortableLong(3.0))); |
| 142 | + doc.add(new SortedNumericDocValuesField(field, NumericUtils.doubleToSortableLong(5.0))); |
| 143 | + indexWriter.addDocument(doc); |
| 144 | + |
| 145 | + try (IndexReader reader = indexWriter.getReader()) { |
| 146 | + IndexSearcher searcher = new IndexSearcher(reader); |
| 147 | + |
| 148 | + MatrixStatsAggregationBuilder avgAgg = new MatrixStatsAggregationBuilder("avg_agg").fields(Collections.singletonList(field)) |
| 149 | + .multiValueMode(MultiValueMode.AVG); |
| 150 | + |
| 151 | + MatrixStatsAggregationBuilder minAgg = new MatrixStatsAggregationBuilder("min_agg").fields(Collections.singletonList(field)) |
| 152 | + .multiValueMode(MultiValueMode.MIN); |
| 153 | + |
| 154 | + InternalMatrixStats avgStats = searchAndReduce(searcher, new MatchAllDocsQuery(), avgAgg, ft); |
| 155 | + InternalMatrixStats minStats = searchAndReduce(searcher, new MatchAllDocsQuery(), minAgg, ft); |
| 156 | + |
| 157 | + double avg = avgStats.getMean(field); |
| 158 | + double min = minStats.getMean(field); |
| 159 | + |
| 160 | + assertNotEquals("AVG and MIN mode should yield different means", avg, min, 0.0001); |
| 161 | + } |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + public void testSerializationDeserialization() throws IOException { |
| 166 | + MatrixStatsAggregationBuilder original = new MatrixStatsAggregationBuilder("test").fields(Collections.singletonList("field")) |
| 167 | + .multiValueMode(MultiValueMode.MIN); |
| 168 | + |
| 169 | + // Serialize |
| 170 | + BytesStreamOutput out = new BytesStreamOutput(); |
| 171 | + out.setVersion(Version.V_3_1_0); |
| 172 | + original.writeTo(out); |
| 173 | + |
| 174 | + // Deserialize |
| 175 | + StreamInput in = out.bytes().streamInput(); |
| 176 | + in.setVersion(Version.V_3_1_0); |
| 177 | + MatrixStatsAggregationBuilder deserialized = new MatrixStatsAggregationBuilder(in); |
| 178 | + |
| 179 | + assertEquals(original.getName(), deserialized.getName()); |
| 180 | + assertEquals(original.fields(), deserialized.fields()); |
| 181 | + assertEquals(original.multiValueMode(), deserialized.multiValueMode()); |
| 182 | + } |
| 183 | + |
| 184 | + public void testDeserializationFallbackToAvg() throws IOException { |
| 185 | + MatrixStatsAggregationBuilder original = new MatrixStatsAggregationBuilder("test").fields(Collections.singletonList("field")); |
| 186 | + |
| 187 | + // Serialize with V_2_3_0 (fallback required) |
| 188 | + BytesStreamOutput out = new BytesStreamOutput(); |
| 189 | + out.setVersion(Version.V_2_3_0); |
| 190 | + original.writeTo(out); |
| 191 | + |
| 192 | + StreamInput in = out.bytes().streamInput(); |
| 193 | + in.setVersion(Version.V_2_3_0); |
| 194 | + MatrixStatsAggregationBuilder deserialized = new MatrixStatsAggregationBuilder(in); |
| 195 | + |
| 196 | + assertEquals(MultiValueMode.AVG, deserialized.multiValueMode()); |
| 197 | + } |
| 198 | + |
| 199 | + public void testEqualsAndHashCode() { |
| 200 | + MatrixStatsAggregationBuilder agg1 = new MatrixStatsAggregationBuilder("agg").fields(Collections.singletonList("field")) |
| 201 | + .multiValueMode(MultiValueMode.AVG); |
| 202 | + |
| 203 | + MatrixStatsAggregationBuilder agg2 = new MatrixStatsAggregationBuilder("agg").fields(Collections.singletonList("field")) |
| 204 | + .multiValueMode(MultiValueMode.AVG); |
| 205 | + |
| 206 | + MatrixStatsAggregationBuilder agg3 = new MatrixStatsAggregationBuilder("agg").fields(Collections.singletonList("field")) |
| 207 | + .multiValueMode(MultiValueMode.MIN); |
| 208 | + |
| 209 | + // equals |
| 210 | + assertEquals(agg1, agg2); |
| 211 | + assertNotEquals(agg1, agg3); |
| 212 | + |
| 213 | + // hashCode |
| 214 | + assertEquals(agg1.hashCode(), agg2.hashCode()); |
| 215 | + assertNotEquals(agg1.hashCode(), agg3.hashCode()); |
| 216 | + } |
| 217 | + |
129 | 218 | @Override
|
130 | 219 | protected List<SearchPlugin> getSearchPlugins() {
|
131 | 220 | return Collections.singletonList(new MatrixAggregationModulePlugin());
|
|
0 commit comments