diff --git a/libs/core/src/main/java/org/opensearch/core/common/bytes/BytesReference.java b/libs/core/src/main/java/org/opensearch/core/common/bytes/BytesReference.java index 9d24d3653397b..5b53e1d9f8bae 100644 --- a/libs/core/src/main/java/org/opensearch/core/common/bytes/BytesReference.java +++ b/libs/core/src/main/java/org/opensearch/core/common/bytes/BytesReference.java @@ -45,6 +45,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.io.Serializable; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -54,7 +55,8 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public interface BytesReference extends Comparable, ToXContentFragment { +public interface BytesReference extends Comparable, ToXContentFragment, Serializable { + // TODO: Remove "Serializable" once we merge in the serializer PR! /** * Convert an {@link XContentBuilder} into a BytesReference. This method closes the builder, diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheDiskTierIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheDiskTierIT.java index 5da0b545e215f..a54d8d06c2119 100644 --- a/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheDiskTierIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheDiskTierIT.java @@ -35,8 +35,10 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.cache.tier.DiskTierTookTimePolicy; import org.opensearch.common.cache.tier.TierType; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.index.cache.request.RequestCacheStats; import org.opensearch.index.cache.request.ShardRequestCache; @@ -54,7 +56,9 @@ public class IndicesRequestCacheDiskTierIT extends OpenSearchIntegTestCase { public void testDiskTierStats() throws Exception { int heapSizeBytes = 4729; String node = internalCluster().startNode( - Settings.builder().put(IndicesRequestCache.INDICES_CACHE_QUERY_SIZE.getKey(), new ByteSizeValue(heapSizeBytes)) + Settings.builder() + .put(IndicesRequestCache.INDICES_CACHE_QUERY_SIZE.getKey(), new ByteSizeValue(heapSizeBytes)) + .put(DiskTierTookTimePolicy.DISK_TOOKTIME_THRESHOLD_SETTING.getKey(), TimeValue.ZERO) // allow into disk cache regardless of took time ); Client client = client(node); diff --git a/server/src/main/java/org/opensearch/common/cache/tier/CachePolicyInfoWrapper.java b/server/src/main/java/org/opensearch/common/cache/tier/CachePolicyInfoWrapper.java new file mode 100644 index 0000000000000..6a077715e8232 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/tier/CachePolicyInfoWrapper.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.cache.tier; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +import java.io.IOException; + +/** + * A class containing information needed for all CacheTierPolicy objects to decide whether to admit + * a given BytesReference. This spares us from having to create an entire short-lived QuerySearchResult object + * just to read a few values. + */ +public class CachePolicyInfoWrapper implements Writeable { + private final Long tookTimeNanos; + public CachePolicyInfoWrapper(Long tookTimeNanos) { + this.tookTimeNanos = tookTimeNanos; + // Add more values here as they are needed for future cache tier policies + } + + public CachePolicyInfoWrapper(StreamInput in) throws IOException { + this.tookTimeNanos = in.readOptionalLong(); + } + + public Long getTookTimeNanos() { + return tookTimeNanos; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalLong(tookTimeNanos); + } +} diff --git a/server/src/main/java/org/opensearch/common/cache/tier/CacheTierPolicy.java b/server/src/main/java/org/opensearch/common/cache/tier/CacheTierPolicy.java new file mode 100644 index 0000000000000..1b5cd2c064397 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/tier/CacheTierPolicy.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.cache.tier; + +public interface CacheTierPolicy { + /** + * Determines whether this policy allows the data into its cache tier. + * @param data The data to check + * @return true if accepted, otherwise false + */ + boolean checkData(T data); +} diff --git a/server/src/main/java/org/opensearch/common/cache/tier/DiskTierTookTimePolicy.java b/server/src/main/java/org/opensearch/common/cache/tier/DiskTierTookTimePolicy.java new file mode 100644 index 0000000000000..4a173624298a2 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/tier/DiskTierTookTimePolicy.java @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.cache.tier; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.cache.tier.CacheTierPolicy; +import org.opensearch.search.query.QuerySearchResult; + +import java.io.IOException; +import java.util.function.Function; + +/** + * A cache tier policy which accepts queries whose took time is greater than some threshold, + * which is specified as a dynamic cluster-level setting. The threshold should be set to approximately + * the time it takes to get a result from the cache tier. + * The policy expects to be able to read a CachePolicyInfoWrapper from the start of the BytesReference. + */ +public class DiskTierTookTimePolicy implements CacheTierPolicy { + public static final Setting DISK_TOOKTIME_THRESHOLD_SETTING = Setting.positiveTimeSetting( + "indices.requests.cache.disk.tooktime.threshold", + new TimeValue(10), + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); // Set this to TimeValue.ZERO to let all data through + + private TimeValue threshold; + private final Function getPolicyInfoFn; + + public DiskTierTookTimePolicy(Settings settings, ClusterSettings clusterSettings, Function getPolicyInfoFn) { + this.threshold = DISK_TOOKTIME_THRESHOLD_SETTING.get(settings); + clusterSettings.addSettingsUpdateConsumer(DISK_TOOKTIME_THRESHOLD_SETTING, this::setThreshold); + this.getPolicyInfoFn = getPolicyInfoFn; + } + + protected void setThreshold(TimeValue threshold) { // protected so that we can manually set value in unit test + this.threshold = threshold; + } + + @Override + public boolean checkData(BytesReference data) { + if (threshold.equals(TimeValue.ZERO)) { + return true; + } + Long tookTimeNanos; + try { + tookTimeNanos = getPolicyInfoFn.apply(data).getTookTimeNanos(); + } catch (Exception e) { + // If we can't retrieve the took time for whatever reason, admit the data to be safe + return true; + } + if (tookTimeNanos == null) { + // Received a null took time -> this QSR is from an old version which does not have took time, we should accept it + return true; + } + TimeValue tookTime = TimeValue.timeValueNanos(tookTimeNanos); + if (tookTime.compareTo(threshold) < 0) { // negative -> tookTime is shorter than threshold + return false; + } + return true; + } +} diff --git a/server/src/main/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyService.java b/server/src/main/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyService.java index f8e037515ae6d..78e4cc5e11f48 100644 --- a/server/src/main/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyService.java +++ b/server/src/main/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyService.java @@ -8,10 +8,15 @@ package org.opensearch.common.cache.tier; +import org.opensearch.common.cache.Cache; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; import org.opensearch.common.cache.RemovalReason; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.indices.IndicesRequestCache; +import org.opensearch.search.query.QuerySearchResult; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -38,6 +43,7 @@ public class TieredCacheSpilloverStrategyService implements TieredCacheSer * Maintains caching tiers in order of get calls. */ private final List> cachingTierList; + private final List> policies; private TieredCacheSpilloverStrategyService(Builder builder) { this.onHeapCachingTier = Objects.requireNonNull(builder.onHeapCachingTier); @@ -45,6 +51,7 @@ private TieredCacheSpilloverStrategyService(Builder builder) { this.tieredCacheEventListener = Objects.requireNonNull(builder.tieredCacheEventListener); this.cachingTierList = this.diskCachingTier.map(diskTier -> Arrays.asList(onHeapCachingTier, diskTier)) .orElse(List.of(onHeapCachingTier)); + this.policies = Objects.requireNonNull(builder.policies); setRemovalListeners(); } @@ -130,10 +137,12 @@ public void onRemoval(RemovalNotification notification) { if (RemovalReason.EVICTED.equals(notification.getRemovalReason())) { switch (notification.getTierType()) { case ON_HEAP: - diskCachingTier.ifPresent(diskTier -> { - diskTier.put(notification.getKey(), notification.getValue()); - tieredCacheEventListener.onCached(notification.getKey(), notification.getValue(), TierType.DISK); - }); + if (checkPolicies(notification.getValue())) { + diskCachingTier.ifPresent(diskTier -> { + diskTier.put(notification.getKey(), notification.getValue()); + tieredCacheEventListener.onCached(notification.getKey(), notification.getValue(), TierType.DISK); + }); + } break; default: break; @@ -152,6 +161,15 @@ public Optional> getDiskCachingTier() { return this.diskCachingTier; } + boolean checkPolicies(V value) { + for (CacheTierPolicy policy : policies) { + if (!policy.checkData(value)) { + return false; + } + } + return true; + } + /** * Register this service as a listener to removal events from different caching tiers. */ @@ -190,6 +208,7 @@ public static class Builder { private OnHeapCachingTier onHeapCachingTier; private DiskCachingTier diskCachingTier; private TieredCacheEventListener tieredCacheEventListener; + private ArrayList> policies = new ArrayList<>(); public Builder() {} @@ -208,6 +227,17 @@ public Builder setTieredCacheEventListener(TieredCacheEventListener return this; } + public Builder withPolicy(CacheTierPolicy policy) { + this.policies.add(policy); + return this; + } + + // Add multiple policies at once + public Builder withPolicies(List> policiesList) { + this.policies.addAll(policiesList); + return this; + } + public TieredCacheSpilloverStrategyService build() { return new TieredCacheSpilloverStrategyService(this); } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index 4cd3490cffb4c..325f28875afa1 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -79,6 +79,7 @@ import org.opensearch.cluster.service.ClusterManagerTaskThrottler; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.cache.tier.DiskTierTookTimePolicy; import org.opensearch.common.logging.Loggers; import org.opensearch.common.network.NetworkModule; import org.opensearch.common.network.NetworkService; @@ -675,7 +676,10 @@ public void apply(Settings value, Settings current, Settings previous) { RemoteClusterStateService.REMOTE_CLUSTER_STATE_ENABLED_SETTING, RemoteStoreNodeService.REMOTE_STORE_COMPATIBILITY_MODE_SETTING, IndicesService.CLUSTER_REMOTE_TRANSLOG_BUFFER_INTERVAL_SETTING, - IndicesService.CLUSTER_REMOTE_INDEX_RESTRICT_ASYNC_DURABILITY_SETTING + IndicesService.CLUSTER_REMOTE_INDEX_RESTRICT_ASYNC_DURABILITY_SETTING, + + // Tiered caching + DiskTierTookTimePolicy.DISK_TOOKTIME_THRESHOLD_SETTING ) ) ); diff --git a/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java b/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java index cb1c94fbd7e29..5530f36201ddf 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java +++ b/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java @@ -40,6 +40,8 @@ import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.CheckedSupplier; import org.opensearch.common.cache.RemovalNotification; +import org.opensearch.common.cache.tier.CachePolicyInfoWrapper; +import org.opensearch.common.cache.tier.DiskTierTookTimePolicy; import org.opensearch.common.cache.tier.BytesReferenceSerializer; import org.opensearch.common.cache.tier.CacheValue; import org.opensearch.common.cache.tier.EhCacheDiskCachingTier; @@ -51,6 +53,7 @@ import org.opensearch.common.cache.tier.TieredCacheService; import org.opensearch.common.cache.tier.TieredCacheSpilloverStrategyService; import org.opensearch.common.lucene.index.OpenSearchDirectoryReader; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Setting.Property; import org.opensearch.common.settings.Settings; @@ -61,6 +64,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.search.query.QuerySearchResult; import java.io.Closeable; import java.io.IOException; @@ -71,6 +75,7 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentMap; +import java.util.function.Function; /** * The indices request cache allows to cache a shard level request stage responses, helping with improving @@ -120,7 +125,7 @@ public final class IndicesRequestCache implements TieredCacheEventListener tieredCacheServiceBuilder = new TieredCacheSpilloverStrategyService.Builder() @@ -142,6 +148,16 @@ public final class IndicesRequestCache implements TieredCacheEventListener ehcacheDiskTier = createNewDiskTier(); tieredCacheServiceBuilder.setOnDiskCachingTier(ehcacheDiskTier); + + // Function to allow took-time policy to inspect took time on cached data. + Function transformationFunction = (data) -> { + try { + return getPolicyInfo(data); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + tieredCacheServiceBuilder.withPolicy(new DiskTierTookTimePolicy(settings, clusterSettings, transformationFunction)); tieredCacheService = tieredCacheServiceBuilder.build(); } @@ -223,6 +239,12 @@ void invalidate(CacheEntity cacheEntity, DirectoryReader reader, BytesReference tieredCacheService.invalidate(new Key(cacheEntity, cacheKey, readerCacheKeyId)); } + public static CachePolicyInfoWrapper getPolicyInfo(BytesReference data) throws IOException { + // Reads the policy info corresponding to this QSR, written in IndicesService$loadIntoContext, + // without having to create a potentially large short-lived QSR object just for this purpose + return new CachePolicyInfoWrapper(data.streamInput()); + } + /** * Loader for the request cache * diff --git a/server/src/main/java/org/opensearch/indices/IndicesService.java b/server/src/main/java/org/opensearch/indices/IndicesService.java index d1f39c9a567e5..44ae73bc24315 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesService.java +++ b/server/src/main/java/org/opensearch/indices/IndicesService.java @@ -61,6 +61,8 @@ import org.opensearch.common.CheckedFunction; import org.opensearch.common.CheckedSupplier; import org.opensearch.common.Nullable; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.tier.CachePolicyInfoWrapper; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lifecycle.AbstractLifecycleComponent; @@ -391,7 +393,7 @@ public IndicesService( this.shardsClosedTimeout = settings.getAsTime(INDICES_SHARDS_CLOSED_TIMEOUT, new TimeValue(1, TimeUnit.DAYS)); this.analysisRegistry = analysisRegistry; this.indexNameExpressionResolver = indexNameExpressionResolver; - this.indicesRequestCache = new IndicesRequestCache(settings, this); + this.indicesRequestCache = new IndicesRequestCache(settings, this, clusterService.getClusterSettings()); this.indicesQueryCache = new IndicesQueryCache(settings); this.mapperRegistry = mapperRegistry; this.namedWriteableRegistry = namedWriteableRegistry; @@ -1674,6 +1676,10 @@ public void loadIntoContext(ShardSearchRequest request, SearchContext context, Q boolean[] loadedFromCache = new boolean[] { true }; BytesReference bytesReference = cacheShardLevelResult(context.indexShard(), directoryReader, request.cacheKey(), out -> { queryPhase.execute(context); + CachePolicyInfoWrapper policyInfo = new CachePolicyInfoWrapper(context.queryResult().getTookTimeNanos()); + policyInfo.writeTo(out); + // Write relevant info for cache tier policies before the whole QuerySearchResult, so we don't have to read + // the whole QSR into memory when we decide whether to allow it into a particular cache tier based on took time/other info context.queryResult().writeToNoId(out); loadedFromCache[0] = false; }); @@ -1682,6 +1688,7 @@ public void loadIntoContext(ShardSearchRequest request, SearchContext context, Q // restore the cached query result into the context final QuerySearchResult result = context.queryResult(); StreamInput in = new NamedWriteableAwareStreamInput(bytesReference.streamInput(), namedWriteableRegistry); + CachePolicyInfoWrapper policyInfo = new CachePolicyInfoWrapper(in); // This wrapper is not needed outside the cache result.readFromWithId(context.id(), in); result.setSearchShardTarget(context.shardTarget()); } else if (context.queryResult().searchTimedOut()) { diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index f3cf2c13ecdef..e6c281fdf74f4 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -131,6 +131,7 @@ public void preProcess(SearchContext context) { } public void execute(SearchContext searchContext) throws QueryPhaseExecutionException { + final long startTime = System.nanoTime(); if (searchContext.hasOnlySuggest()) { suggestProcessor.process(searchContext); searchContext.queryResult() @@ -138,6 +139,7 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN), new DocValueFormat[0] ); + searchContext.queryResult().setTookTimeNanos(System.nanoTime() - startTime); return; } @@ -165,6 +167,7 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep ); searchContext.queryResult().profileResults(shardResults); } + searchContext.queryResult().setTookTimeNanos(System.nanoTime() - startTime); } // making public for testing @@ -292,7 +295,6 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q queryResult.nodeQueueSize(rExecutor.getCurrentQueueSize()); queryResult.serviceTimeEWMA((long) rExecutor.getTaskExecutionEWMA()); } - return shouldRescore; } finally { // Search phase has finished, no longer need to check for timeout diff --git a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java index 7de605a244d09..74e9dfc97357d 100644 --- a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java @@ -34,6 +34,7 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.TotalHits; +import org.opensearch.Version; import org.opensearch.common.io.stream.DelayableWriteable; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.common.io.stream.StreamInput; @@ -87,6 +88,7 @@ public final class QuerySearchResult extends SearchPhaseResult { private int nodeQueueSize = -1; private final boolean isNull; + private Long tookTimeNanos = null; public QuerySearchResult() { this(false); @@ -364,6 +366,11 @@ public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOExc nodeQueueSize = in.readInt(); setShardSearchRequest(in.readOptionalWriteable(ShardSearchRequest::new)); setRescoreDocIds(new RescoreDocIds(in)); + if (in.getVersion().onOrAfter(Version.V_3_0_0)) { + tookTimeNanos = in.readOptionalLong(); + } else { + tookTimeNanos = null; + } } @Override @@ -406,6 +413,9 @@ public void writeToNoId(StreamOutput out) throws IOException { out.writeInt(nodeQueueSize); out.writeOptionalWriteable(getShardSearchRequest()); getRescoreDocIds().writeTo(out); + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { + out.writeOptionalLong(tookTimeNanos); + } } public TotalHits getTotalHits() { @@ -415,4 +425,12 @@ public TotalHits getTotalHits() { public float getMaxScore() { return maxScore; } + + public Long getTookTimeNanos() { + return tookTimeNanos; + } + + public void setTookTimeNanos(long tookTime) { + tookTimeNanos = tookTime; + } } diff --git a/server/src/test/java/org/opensearch/common/cache/tier/DiskTierTookTimePolicyTests.java b/server/src/test/java/org/opensearch/common/cache/tier/DiskTierTookTimePolicyTests.java new file mode 100644 index 0000000000000..1c5d5ce71fc73 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/cache/tier/DiskTierTookTimePolicyTests.java @@ -0,0 +1,144 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.cache.tier; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.OriginalIndicesTests; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.UUIDs; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.indices.IndicesRequestCache; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.function.Function; + +public class DiskTierTookTimePolicyTests extends OpenSearchTestCase { + private final Function transformationFunction = (data) -> { + try { + return IndicesRequestCache.getPolicyInfo(data); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + private DiskTierTookTimePolicy getTookTimePolicy() { + // dummy settings + Settings dummySettings = Settings.EMPTY; + ClusterSettings dummyClusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + return new DiskTierTookTimePolicy(dummySettings, dummyClusterSettings, transformationFunction); + } + + public void testQSRSetupFunction() throws IOException { + Long ttn = 100000000000L; + QuerySearchResult qsr = getQSR(ttn); + assertEquals(ttn, qsr.getTookTimeNanos()); + } + public void testTookTimePolicy() throws Exception { + DiskTierTookTimePolicy tookTimePolicy = getTookTimePolicy(); + + // manually set threshold for test + double threshMillis = 10; + long shortMillis = (long) (0.9 * threshMillis); + long longMillis = (long) (1.5 * threshMillis); + tookTimePolicy.setThreshold(new TimeValue((long) threshMillis)); + BytesReference shortTime = getValidPolicyInput(getQSR(shortMillis * 1000000)); + BytesReference longTime = getValidPolicyInput(getQSR(longMillis * 1000000)); + + boolean shortResult = tookTimePolicy.checkData(shortTime); + assertFalse(shortResult); + boolean longResult = tookTimePolicy.checkData(longTime); + assertTrue(longResult); + + DiskTierTookTimePolicy disabledPolicy = getTookTimePolicy(); + disabledPolicy.setThreshold(TimeValue.ZERO); + shortResult = disabledPolicy.checkData(shortTime); + assertTrue(shortResult); + longResult = disabledPolicy.checkData(longTime); + assertTrue(longResult); + } + + public static QuerySearchResult getQSR(long tookTimeNanos) { + // package-private, also used by IndicesRequestCacheTests.java + // setup from QuerySearchResultTests.java + ShardId shardId = new ShardId("index", "uuid", randomInt()); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()); + ShardSearchRequest shardSearchRequest = new ShardSearchRequest( + OriginalIndicesTests.randomOriginalIndices(), + searchRequest, + shardId, + 1, + new AliasFilter(null, Strings.EMPTY_ARRAY), + 1.0f, + randomNonNegativeLong(), + null, + new String[0] + ); + ShardSearchContextId id = new ShardSearchContextId(UUIDs.base64UUID(), randomLong()); + QuerySearchResult result = new QuerySearchResult( + id, + new SearchShardTarget("node", shardId, null, OriginalIndices.NONE), + shardSearchRequest + ); + TopDocs topDocs = new TopDocs(new TotalHits(randomLongBetween(0, Long.MAX_VALUE), TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + result.topDocs(new TopDocsAndMaxScore(topDocs, randomBoolean() ? Float.NaN : randomFloat()), new DocValueFormat[0]); + + result.setTookTimeNanos(tookTimeNanos); + return result; + } + + private BytesReference getValidPolicyInput(QuerySearchResult qsr) throws IOException { + // When it's used in the cache, the policy will receive BytesReferences which have a CachePolicyInfoWrapper + // at the beginning of them, followed by the actual QSR. + CachePolicyInfoWrapper policyInfo = new CachePolicyInfoWrapper(qsr.getTookTimeNanos()); + BytesStreamOutput out = new BytesStreamOutput(); + policyInfo.writeTo(out); + qsr.writeTo(out); + return out.bytes(); + } +} diff --git a/server/src/test/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyServiceTests.java b/server/src/test/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyServiceTests.java index a85d82118ff66..f0b286b13b200 100644 --- a/server/src/test/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyServiceTests.java +++ b/server/src/test/java/org/opensearch/common/cache/tier/TieredCacheSpilloverStrategyServiceTests.java @@ -15,11 +15,15 @@ import org.opensearch.test.OpenSearchTestCase; import java.util.ArrayList; +import java.util.Arrays; import java.util.EnumMap; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; public class TieredCacheSpilloverStrategyServiceTests extends OpenSearchTestCase { @@ -29,7 +33,8 @@ public void testComputeAndAbsentWithoutAnyOnHeapCacheEviction() throws Exception TieredCacheSpilloverStrategyService spilloverStrategyService = intializeTieredCacheService( onHeapCacheSize, randomIntBetween(1, 4), - eventListener + eventListener, + null ); int numOfItems1 = randomIntBetween(1, onHeapCacheSize / 2 - 1); List keys = new ArrayList<>(); @@ -73,7 +78,8 @@ public void testComputeAndAbsentWithEvictionsFromOnHeapCache() throws Exception TieredCacheSpilloverStrategyService spilloverStrategyService = intializeTieredCacheService( onHeapCacheSize, diskCacheSize, - eventListener + eventListener, + null ); // Put values in cache more than it's size and cause evictions from onHeap. @@ -140,7 +146,8 @@ public void testComputeAndAbsentWithEvictionsFromBothTier() throws Exception { TieredCacheSpilloverStrategyService spilloverStrategyService = intializeTieredCacheService( onHeapCacheSize, diskCacheSize, - eventListener + eventListener, + null ); int numOfItems = randomIntBetween(totalSize + 1, totalSize * 3); @@ -161,7 +168,8 @@ public void testGetAndCount() throws Exception { TieredCacheSpilloverStrategyService spilloverStrategyService = intializeTieredCacheService( onHeapCacheSize, diskCacheSize, - eventListener + eventListener, + null ); int numOfItems1 = randomIntBetween(onHeapCacheSize + 1, totalSize); @@ -198,9 +206,12 @@ public void testGetAndCount() throws Exception { public void testWithDiskTierNull() throws Exception { int onHeapCacheSize = randomIntBetween(10, 30); MockTieredCacheEventListener eventListener = new MockTieredCacheEventListener(); + Function identityFunction = (String value) -> { return value; }; TieredCacheSpilloverStrategyService spilloverStrategyService = new TieredCacheSpilloverStrategyService.Builder< String, - String>().setOnHeapCachingTier(new MockOnHeapCacheTier<>(onHeapCacheSize)).setTieredCacheEventListener(eventListener).build(); + String>().setOnHeapCachingTier(new MockOnHeapCacheTier<>(onHeapCacheSize)) + .setTieredCacheEventListener(eventListener) + .build(); int numOfItems = randomIntBetween(onHeapCacheSize + 1, onHeapCacheSize * 3); for (int iter = 0; iter < numOfItems; iter++) { TieredCacheLoader tieredCacheLoader = getTieredCacheLoader(); @@ -212,6 +223,70 @@ public void testWithDiskTierNull() throws Exception { assertEquals(0, eventListener.enumMap.get(TierType.DISK).missCount.count()); } + public void testDiskTierPolicies() throws Exception { + // For policy function, allow if what it receives starts with "a" and string is even length + ArrayList> policies = new ArrayList<>(); + policies.add(new AllowFirstLetterA()); + policies.add(new AllowEvenLengths()); + + int onHeapCacheSize = 0; + int diskCacheSize = 10000; + MockTieredCacheEventListener eventListener = new MockTieredCacheEventListener(); + TieredCacheSpilloverStrategyService spilloverStrategyService = intializeTieredCacheService( + onHeapCacheSize, + diskCacheSize, + eventListener, + policies + ); + + Map keyValuePairs = new HashMap<>(); + Map expectedOutputs = new HashMap<>(); + keyValuePairs.put("key1", "abcd"); + expectedOutputs.put("key1", true); + keyValuePairs.put("key2", "abcde"); + expectedOutputs.put("key2", false); + keyValuePairs.put("key3", "bbc"); + expectedOutputs.put("key3", false); + keyValuePairs.put("key4", "ab"); + expectedOutputs.put("key4", true); + keyValuePairs.put("key5", ""); + expectedOutputs.put("key5", false); + + TieredCacheLoader loader = getTieredCacheLoaderWithKeyValueMap(keyValuePairs); + + for (String key : keyValuePairs.keySet()) { + Boolean expectedOutput = expectedOutputs.get(key); + String value = spilloverStrategyService.computeIfAbsent(key, loader); + assertEquals(keyValuePairs.get(key), value); + String result = spilloverStrategyService.get(key); + if (expectedOutput) { + // Should retrieve from disk tier if it was accepted + assertEquals(keyValuePairs.get(key), result); + } else { + // Should miss as heap tier size = 0 and the policy rejected it + assertNull(result); + } + } + } + + private static class AllowFirstLetterA implements CacheTierPolicy { + @Override + public boolean checkData(String data) { + try { + return (data.charAt(0) == 'a'); + } catch (StringIndexOutOfBoundsException e) { + return false; + } + } + } + + private static class AllowEvenLengths implements CacheTierPolicy { + @Override + public boolean checkData(String data) { + return data.length() % 2 == 0; + } + } + private TieredCacheLoader getTieredCacheLoader() { return new TieredCacheLoader() { boolean isLoaded = false; @@ -229,16 +304,41 @@ public boolean isLoaded() { }; } + private TieredCacheLoader getTieredCacheLoaderWithKeyValueMap(Map map) { + return new TieredCacheLoader() { + boolean isLoaded; + @Override + public String load(String key) throws Exception { + isLoaded = true; + return map.get(key); + } + + @Override + public boolean isLoaded() { + return isLoaded; + } + }; + } + private TieredCacheSpilloverStrategyService intializeTieredCacheService( int onHeapCacheSize, - int diksCacheSize, - TieredCacheEventListener cacheEventListener + int diskCacheSize, + TieredCacheEventListener cacheEventListener, + List> policies // If passed null, default to no policies (empty list) ) { - DiskCachingTier diskCache = new MockDiskCachingTier<>(diksCacheSize); + DiskCachingTier diskCache = new MockDiskCachingTier<>(diskCacheSize); OnHeapCachingTier openSearchOnHeapCache = new MockOnHeapCacheTier<>(onHeapCacheSize); + + List> policiesToUse = new ArrayList<>(); + if (policies != null) { + policiesToUse = policies; + } + + return new TieredCacheSpilloverStrategyService.Builder().setOnHeapCachingTier(openSearchOnHeapCache) .setOnDiskCachingTier(diskCache) .setTieredCacheEventListener(cacheEventListener) + .withPolicies(policiesToUse) .build(); } diff --git a/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java b/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java index 22d185a02d1a4..c18250bb6bec2 100644 --- a/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java +++ b/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java @@ -24,8 +24,9 @@ public class IRCKeyWriteableSerializerTests extends OpenSearchSingleNodeTestCase { public void testSerializer() throws Exception { + ClusterSettings dummyClusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); IndicesService indicesService = getInstanceFromNode(IndicesService.class); - IndicesRequestCache irc = new IndicesRequestCache(Settings.EMPTY, indicesService); + IndicesRequestCache irc = new IndicesRequestCache(Settings.EMPTY, indicesService, dummyClusterSettings); IndexService indexService = createIndex("test"); IndexShard indexShard = indexService.getShardOrNull(0); IndicesService.IndexShardCacheEntity entity = indicesService.new IndexShardCacheEntity(indexShard); diff --git a/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java b/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java index 5fbffe6906d56..c2e7b4ad26bfd 100644 --- a/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java +++ b/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java @@ -41,16 +41,26 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.Term; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.OriginalIndicesTests; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.CheckedSupplier; +import org.opensearch.common.UUIDs; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.lucene.index.OpenSearchDirectoryReader; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.io.IOUtils; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.AbstractBytesReference; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -62,6 +72,12 @@ import org.opensearch.index.cache.request.ShardRequestCache; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.shard.IndexShard; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchSingleNodeTestCase; import java.io.IOException; @@ -73,7 +89,8 @@ public class IndicesRequestCacheTests extends OpenSearchSingleNodeTestCase { public void testBasicOperationsCache() throws Exception { ShardRequestCache requestCacheStats = new ShardRequestCache(); - IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class)); + ClusterSettings dummyClusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class), dummyClusterSettings); Directory dir = newDirectory(); IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig()); @@ -127,7 +144,8 @@ public void testBasicOperationsCache() throws Exception { } public void testCacheDifferentReaders() throws Exception { - IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class)); + ClusterSettings dummyClusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class), dummyClusterSettings); AtomicBoolean indexShard = new AtomicBoolean(true); ShardRequestCache requestCacheStats = new ShardRequestCache(); Directory dir = newDirectory(); @@ -222,8 +240,9 @@ public void testCacheDifferentReaders() throws Exception { public void testEviction() throws Exception { final ByteSizeValue size; + ClusterSettings dummyClusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); { - IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class)); + IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class), dummyClusterSettings); AtomicBoolean indexShard = new AtomicBoolean(true); ShardRequestCache requestCacheStats = new ShardRequestCache(); Directory dir = newDirectory(); @@ -250,7 +269,8 @@ public void testEviction() throws Exception { } IndicesRequestCache cache = new IndicesRequestCache( Settings.builder().put(IndicesRequestCache.INDICES_CACHE_QUERY_SIZE.getKey(), size.getBytes() + 1 + "b").build(), - getInstanceFromNode(IndicesService.class) + getInstanceFromNode(IndicesService.class), + dummyClusterSettings ); AtomicBoolean indexShard = new AtomicBoolean(true); ShardRequestCache requestCacheStats = new ShardRequestCache(); @@ -287,7 +307,8 @@ public void testEviction() throws Exception { } public void testClearAllEntityIdentity() throws Exception { - IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class)); + ClusterSettings dummyClusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class), dummyClusterSettings); AtomicBoolean indexShard = new AtomicBoolean(true); ShardRequestCache requestCacheStats = new ShardRequestCache(); @@ -372,7 +393,8 @@ public BytesReference get() { public void testInvalidate() throws Exception { ShardRequestCache requestCacheStats = new ShardRequestCache(); - IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class)); + ClusterSettings dummyClusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, getInstanceFromNode(IndicesService.class), dummyClusterSettings); Directory dir = newDirectory(); IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig()); @@ -468,6 +490,40 @@ public void testEqualsKey() throws IOException { assertNotEquals(key1, key5); } + private static BytesReference getQSRBytesReference(long tookTimeNanos) throws IOException { + // unfortunately no good way to separate this out from DiskTierTookTimePolicyTests.getQSR() :( + ShardId shardId = new ShardId("index", "uuid", randomInt()); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()); + ShardSearchRequest shardSearchRequest = new ShardSearchRequest( + OriginalIndicesTests.randomOriginalIndices(), + searchRequest, + shardId, + 1, + new AliasFilter(null, Strings.EMPTY_ARRAY), + 1.0f, + randomNonNegativeLong(), + null, + new String[0] + ); + ShardSearchContextId id = new ShardSearchContextId(UUIDs.base64UUID(), randomLong()); + QuerySearchResult result = new QuerySearchResult( + id, + new SearchShardTarget("node", shardId, null, OriginalIndices.NONE), + shardSearchRequest + ); + TopDocs topDocs = new TopDocs(new TotalHits(randomLongBetween(0, Long.MAX_VALUE), TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + result.topDocs(new TopDocsAndMaxScore(topDocs, randomBoolean() ? Float.NaN : randomFloat()), new DocValueFormat[0]); + + result.setTookTimeNanos(tookTimeNanos); + + BytesStreamOutput out = new BytesStreamOutput(); + // it appears to need a boolean and then a ShardSearchContextId written to the stream before the QSR in order to deserialize? + out.writeBoolean(false); + id.writeTo(out); + result.writeToNoId(out); + return out.bytes(); + } + private class TestBytesReference extends AbstractBytesReference { int dummyValue; diff --git a/server/src/test/java/org/opensearch/search/SearchServiceTests.java b/server/src/test/java/org/opensearch/search/SearchServiceTests.java index 7c84078af080e..caef2a450044f 100644 --- a/server/src/test/java/org/opensearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/opensearch/search/SearchServiceTests.java @@ -54,6 +54,7 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.WriteRequest; import org.opensearch.common.UUIDs; +import org.opensearch.common.recycler.Recycler; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.SettingsException; import org.opensearch.common.unit.TimeValue; @@ -121,6 +122,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; @@ -823,6 +825,116 @@ public Scroll scroll() { } } + public void testQuerySearchResultTookTime() throws Exception { + // I wasn't able to introduce a delay in these tests as everything between creation and usage of the QuerySearchResult object + // happen in a single line - we would have to modify QueryPhase.execute() to take a delay parameter + // However this was tested manually + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())); + + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 2, // must have >1 shards for executeQueryPhase to return the QuerySearchResult + new AliasFilter(null, Strings.EMPTY_ARRAY), + 1.0f, + -1, + null, + null + ); + + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + service.executeQueryPhase(request, randomBoolean(), task, new ActionListener() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + assertEquals(QuerySearchResult.class, searchPhaseResult.getClass()); // 2+ shards -> QuerySearchResult returned + QuerySearchResult qsr = (QuerySearchResult) searchPhaseResult; + assertTrue(qsr.getTookTimeNanos() > 0); // Above zero means it's been set at some point + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + }); + } + public void testQuerySearchResultTookTimeCacheableRequest() throws Exception { + // Test 2 identical cacheable requests and assert both have the same tookTime + // Similarly, no delay could be added + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + + searchRequest.source(searchSourceBuilder); + searchSourceBuilder.scriptField( + "field" + 0, + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap()) + ); + searchSourceBuilder.size(0); // from testIgnoreScriptfieldIfSizeZero + + String[] dummyRoutings = new String[]{}; + OriginalIndices dummyOriginalIndices = new OriginalIndices(new String[]{"index'"}, IndicesOptions.LENIENT_EXPAND_OPEN); + + ShardSearchRequest request = new ShardSearchRequest( + dummyOriginalIndices, + searchRequest, + indexShard.shardId(), + 2, // must have >1 shards for executeQueryPhase to return the QuerySearchResult + new AliasFilter(null, Strings.EMPTY_ARRAY), + 1.0f, + 0L, + // if nowInMillis is negative, it fails when trying to write the shardSearchRequest to cache as it uses WriteVLong which only takes positive longs + null, + dummyRoutings // similar for routings + ); + + final CompletableFuture firstResult = new CompletableFuture<>(); + final CompletableFuture secondResult = new CompletableFuture<>(); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + service.executeQueryPhase(request, randomBoolean(), task, new ActionListener() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + assertEquals(QuerySearchResult.class, searchPhaseResult.getClass()); // 2+ shards -> QuerySearchResult returned + QuerySearchResult qsr = (QuerySearchResult) searchPhaseResult; + firstResult.complete(qsr.getTookTimeNanos()); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + }); + + service.executeQueryPhase(request, randomBoolean(), task, new ActionListener() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + assertEquals(QuerySearchResult.class, searchPhaseResult.getClass()); // 2+ shards -> QuerySearchResult returned + QuerySearchResult qsr = (QuerySearchResult) searchPhaseResult; + secondResult.complete(qsr.getTookTimeNanos()); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + }); + + long firstResultVal = firstResult.get(); + long secondResultVal = secondResult.get(); + assertEquals(firstResultVal, secondResultVal); + assertTrue(firstResultVal > 0); + } + public void testCanMatch() throws Exception { createIndex("index"); final SearchService service = getInstanceFromNode(SearchService.class); @@ -1010,6 +1122,7 @@ public void onFailure(Exception e) { } } }); + latch.await(); } diff --git a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java index 39126a607f968..ef30cea39be5c 100644 --- a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java @@ -85,9 +85,14 @@ import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchShardTask; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -103,9 +108,11 @@ import org.opensearch.lucene.queries.MinDocQuery; import org.opensearch.search.DocValueFormat; import org.opensearch.search.collapse.CollapseBuilder; +import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.sort.SortAndFormats; import org.opensearch.test.TestSearchContext; import org.opensearch.threadpool.ThreadPool; @@ -115,6 +122,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.LinkedList; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -1145,6 +1153,114 @@ public void testQueryTimeoutChecker() throws Exception { createTimeoutCheckerThenWaitThenRun(timeCacheLifespan / 4, timeCacheLifespan / 2 + timeTolerance, false, true); } + public void testQuerySearchResultTookTime() throws IOException { + int sleepMillis = randomIntBetween(10, 100); // between 0.01 and 0.1 sec + DelayedQueryPhaseSearcher delayedQueryPhaseSearcher = new DelayedQueryPhaseSearcher(sleepMillis); + + // we need to test queryPhase.execute(), not executeInternal(), since that's what the timer wraps around + // for that we must set up a searchContext with more functionality than the TestSearchContext, + // which requires a bit of complexity with test classes + + Directory dir = newDirectory(); + final Sort sort = new Sort(new SortField("rank", SortField.Type.INT)); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + for (int i = 0; i < 10; i++) { + doc.add(new StringField("foo", Integer.toString(i), Store.NO)); + } + w.addDocument(doc); + w.close(); + IndexReader reader = DirectoryReader.open(dir); + + QueryShardContext queryShardContext = mock(QueryShardContext.class); + when(queryShardContext.fieldMapper("user")).thenReturn( + new NumberFieldType("user", NumberType.INTEGER, true, false, true, false, null, Collections.emptyMap()) + ); + + Index index = new Index("IndexName", "UUID"); + ShardId shardId = new ShardId(index, 0); + long nowInMillis = System.currentTimeMillis(); + String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(3, 10); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.allowPartialSearchResults(randomBoolean()); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + shardId, + 1, + AliasFilter.EMPTY, + 1f, + nowInMillis, + clusterAlias, + Strings.EMPTY_ARRAY + ); + TestSearchContextWithRequest searchContext = new TestSearchContextWithRequest( + queryShardContext, + indexShard, + newEarlyTerminationContextSearcher(reader, 0, executor), + request + ); + + QueryPhase queryPhase = new QueryPhase(delayedQueryPhaseSearcher); + queryPhase.execute(searchContext); + Long tookTime = searchContext.queryResult().getTookTimeNanos(); + assertTrue(tookTime >= (long) sleepMillis * 1000000); + reader.close(); + dir.close(); + } + + private class TestSearchContextWithRequest extends TestSearchContext { + ShardSearchRequest request; + Query query; + + public TestSearchContextWithRequest( + QueryShardContext queryShardContext, + IndexShard indexShard, + ContextIndexSearcher searcher, + ShardSearchRequest request + ) { + super(queryShardContext, indexShard, searcher); + this.request = request; + this.query = new TermQuery(new Term("foo", "bar")); + } + + @Override + public ShardSearchRequest request() { + return request; + } + + @Override + public Query query() { + return this.query; + } + } + + private class DelayedQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher implements QueryPhaseSearcher { + // add delay into searchWith + private final int sleepMillis; + + public DelayedQueryPhaseSearcher(int sleepMillis) { + super(); + this.sleepMillis = sleepMillis; + } + + @Override + public boolean searchWith( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + try { + Thread.sleep(sleepMillis); + } catch (Exception ignored) {} + return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } + } + private void createTimeoutCheckerThenWaitThenRun( long timeout, long sleepAfterCreation, diff --git a/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java b/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java index 41e4e1ae45a73..1b8fc9d7dbc5c 100644 --- a/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java +++ b/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java @@ -56,6 +56,8 @@ import org.opensearch.search.suggest.SuggestTests; import org.opensearch.test.OpenSearchTestCase; +import java.util.HashMap; + import static java.util.Collections.emptyList; public class QuerySearchResultTests extends OpenSearchTestCase { @@ -99,25 +101,36 @@ private static QuerySearchResult createTestInstance() throws Exception { if (randomBoolean()) { result.aggregations(InternalAggregationsTests.createTestInstance()); } + assertNull(result.getTookTimeNanos()); return result; } public void testSerialization() throws Exception { - QuerySearchResult querySearchResult = createTestInstance(); - QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, QuerySearchResult::new); - assertEquals(querySearchResult.getContextId().getId(), deserialized.getContextId().getId()); - assertNull(deserialized.getSearchShardTarget()); - assertEquals(querySearchResult.topDocs().maxScore, deserialized.topDocs().maxScore, 0f); - assertEquals(querySearchResult.topDocs().topDocs.totalHits, deserialized.topDocs().topDocs.totalHits); - assertEquals(querySearchResult.from(), deserialized.from()); - assertEquals(querySearchResult.size(), deserialized.size()); - assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs()); - if (deserialized.hasAggs()) { - Aggregations aggs = querySearchResult.consumeAggs().expand(); - Aggregations deserializedAggs = deserialized.consumeAggs().expand(); - assertEquals(aggs.asList(), deserializedAggs.asList()); + HashMap expectedValues = new HashMap<>(); // map contains whether to set took time, and if so, to what value + expectedValues.put(false, null); + expectedValues.put(true, 1000L); + for (Boolean doSetTookTime : expectedValues.keySet()) { + QuerySearchResult querySearchResult = createTestInstance(); + if (doSetTookTime) { + querySearchResult.setTookTimeNanos(expectedValues.get(doSetTookTime)); + } + QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, QuerySearchResult::new); + assertEquals(querySearchResult.getContextId().getId(), deserialized.getContextId().getId()); + assertNull(deserialized.getSearchShardTarget()); + assertEquals(querySearchResult.topDocs().maxScore, deserialized.topDocs().maxScore, 0f); + assertEquals(querySearchResult.topDocs().topDocs.totalHits, deserialized.topDocs().topDocs.totalHits); + assertEquals(querySearchResult.from(), deserialized.from()); + assertEquals(querySearchResult.size(), deserialized.size()); + assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs()); + if (deserialized.hasAggs()) { + Aggregations aggs = querySearchResult.consumeAggs().expand(); + Aggregations deserializedAggs = deserialized.consumeAggs().expand(); + assertEquals(aggs.asList(), deserializedAggs.asList()); + } + assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly()); + assertEquals(querySearchResult.getTookTimeNanos(), deserialized.getTookTimeNanos()); + assertEquals(expectedValues.get(doSetTookTime), querySearchResult.getTookTimeNanos()); } - assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly()); } public void testNullResponse() throws Exception {