From b25518a88398df3619a5dd34e82e51fb8e612827 Mon Sep 17 00:00:00 2001 From: Armin Date: Sat, 26 Apr 2025 23:06:02 +0200 Subject: [PATCH 1/2] wee --- .../common/bytes/CompositeBytesReference.java | 4 + .../bytes/ReleasableBytesReference.java | 138 ++++++++++++++++++ .../transport/InboundAggregator.java | 27 ++-- .../transport/InboundMessage.java | 27 ++-- .../transport/InboundHandlerTests.java | 17 ++- 5 files changed, 180 insertions(+), 33 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/bytes/CompositeBytesReference.java b/server/src/main/java/org/elasticsearch/common/bytes/CompositeBytesReference.java index 537082fedd602..3d4e21e6cc8ec 100644 --- a/server/src/main/java/org/elasticsearch/common/bytes/CompositeBytesReference.java +++ b/server/src/main/java/org/elasticsearch/common/bytes/CompositeBytesReference.java @@ -168,6 +168,10 @@ public BytesReference slice(int from, int length) { return CompositeBytesReference.ofMultiple(inSlice); } + public BytesReference[] components() { + return references; + } + private int getOffsetIndex(int offset) { final int i = Arrays.binarySearch(offsets, offset); return i < 0 ? (-(i + 1)) - 1 : i; diff --git a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java b/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java index 4343892428c9a..86d6908d7e58d 100644 --- a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java +++ b/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java @@ -157,6 +157,144 @@ public long ramBytesUsed() { return delegate.ramBytesUsed(); } + public static StreamInput consumingStreamInput(ReleasableBytesReference... 
references) throws IOException { + final BytesReference bytesReference; + final RefCounted[] refs = new RefCounted[references.length]; + if (references.length == 1) { + final var ref = references[0]; + bytesReference = ref; + refs[0] = ref.refCounted; + } else { + bytesReference = CompositeBytesReference.of(references); + for (int i = 0; i < references.length; i++) { + refs[i] = references[i].refCounted; + } + } + return new BytesReferenceStreamInput(bytesReference) { + private ReleasableBytesReference retainAndSkip(int len) throws IOException { + if (len == 0) { + return ReleasableBytesReference.empty(); + } + + int offset = offset(); + skip(len); + // move the stream manually since creating the slice didn't move it + if (bytesReference instanceof ReleasableBytesReference releasable) { + ReleasableBytesReference res = releasable.retainedSlice(offset, len); + if (markEnd == 0 && available() == 0) { + close(); + } + return res; + } + assert bytesReference instanceof CompositeBytesReference; + final CompositeBytesReference composite = (CompositeBytesReference) bytesReference; + // instead of reading the bytes from a stream we just create a slice of the underlying bytes + final BytesReference result = composite.slice(offset, len); + if (result instanceof ReleasableBytesReference releasable) { + return releasable.retain(); + } + assert result instanceof CompositeBytesReference; + var compositeSlice = (CompositeBytesReference) result; + var components = compositeSlice.components(); + final RefCounted[] refCounteds = new RefCounted[components.length]; + for (int i = 0; i < components.length; i++) { + refCounteds[i] = ((ReleasableBytesReference) components[i]).retain(); + } + if (markEnd == 0) { + maybeDiscardReadBytes(composite.components()); + } + return new ReleasableBytesReference(compositeSlice, () -> { + for (int i = 0; i < refCounteds.length; i++) { + refCounteds[i].decRef(); + refCounteds[i] = null; + } + }); + } + + private void 
maybeDiscardReadBytes(BytesReference[] components) { + int offset = offset(); + int p = 0; + for (int i = 0; i < components.length; i++) { + p += components[i].length(); + if (p >= offset) { + return; + } + var r = refs[i]; + if (r != null) { + r.decRef(); + refs[i] = null; + } + } + } + + @Override + public ReleasableBytesReference readReleasableBytesReference() throws IOException { + final int len = readVInt(); + return retainAndSkip(len); + } + + @Override + public ReleasableBytesReference readReleasableBytesReference(int len) throws IOException { + return retainAndSkip(len); + } + + @Override + public ReleasableBytesReference readAllToReleasableBytesReference() throws IOException { + return retainAndSkip(bytesReference.length() - offset()); + } + + @Override + public boolean supportReadAllToReleasableBytesReference() { + return true; + } + + @Override + public void close() { + for (int i = 0; i < refs.length; i++) { + RefCounted ref = refs[i]; + if (ref != null) { + refs[i] = null; + ref.decRef(); + } + } + } + + @Override + public int read(byte[] b, int bOffset, int len) throws IOException { + int res = super.read(b, bOffset, len); + if (markEnd == 0) { + tryDiscard(); + } + return res; + } + + private void tryDiscard() { + if (bytesReference instanceof CompositeBytesReference c) { + maybeDiscardReadBytes(c.components()); + } else if (available() == 0) { + close(); + } + } + + @Override + public int read() throws IOException { + int res = super.read(); + if (res == -1 && markEnd == 0) { + close(); + } + return res; + } + + private int markEnd = 0; + + @Override + public void mark(int readLimit) { + super.mark(readLimit); + markEnd = offset() + readLimit; + } + }; + } + @Override public StreamInput streamInput() throws IOException { assert hasReferences(); diff --git a/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java b/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java index 9e4b1b427d080..0aa9416edc7df 100644 --- 
a/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java @@ -11,8 +11,6 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -95,19 +93,30 @@ public void aggregate(ReleasableBytesReference content) { public InboundMessage finishAggregation() throws IOException { ensureOpen(); - final ReleasableBytesReference releasableContent; + final ReleasableBytesReference[] releasableContent; + final int len; if (isFirstContent()) { - releasableContent = ReleasableBytesReference.empty(); + releasableContent = new ReleasableBytesReference[] { ReleasableBytesReference.empty() }; + len = 0; } else if (contentAggregation == null) { - releasableContent = firstContent; + releasableContent = new ReleasableBytesReference[] { firstContent }; + len = firstContent.length(); } else { - final ReleasableBytesReference[] references = contentAggregation.toArray(new ReleasableBytesReference[0]); - final BytesReference content = CompositeBytesReference.of(references); - releasableContent = new ReleasableBytesReference(content, () -> Releasables.close(references)); + releasableContent = contentAggregation.toArray(new ReleasableBytesReference[0]); + int l = 0; + for (ReleasableBytesReference releasableBytesReference : releasableContent) { + l += releasableBytesReference.length(); + } + len = l; } final BreakerControl breakerControl = new BreakerControl(circuitBreaker); - final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl); + final InboundMessage aggregated = new InboundMessage( + currentHeader, + 
ReleasableBytesReference.consumingStreamInput(releasableContent), + len, + breakerControl + ); boolean success = false; try { if (aggregated.getHeader().needsToReadVariableHeader()) { diff --git a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java index ecf8f4bbbf412..3c33b407f3d1a 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java @@ -10,7 +10,6 @@ package org.elasticsearch.transport; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Releasable; @@ -23,7 +22,7 @@ public class InboundMessage implements Releasable { private final Header header; - private final ReleasableBytesReference content; + private final int contentLength; private final Exception exception; private final boolean isPing; private Releasable breakerRelease; @@ -42,17 +41,19 @@ public class InboundMessage implements Releasable { } } - public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { + public InboundMessage(Header header, StreamInput streamInput, int contentLength, Releasable breakerRelease) { this.header = header; - this.content = content; + this.streamInput = streamInput; + streamInput.setTransportVersion(header.getVersion()); this.breakerRelease = breakerRelease; this.exception = null; this.isPing = false; + this.contentLength = contentLength; } public InboundMessage(Header header, Exception exception) { this.header = header; - this.content = null; + this.contentLength = 0; this.breakerRelease = null; this.exception = exception; this.isPing = false; @@ -60,7 +61,7 @@ public InboundMessage(Header header, Exception exception) { public InboundMessage(Header header, boolean 
isPing) { this.header = header; - this.content = null; + this.contentLength = 0; this.breakerRelease = null; this.exception = null; this.isPing = isPing; @@ -71,11 +72,7 @@ public Header getHeader() { } public int getContentLength() { - if (content == null) { - return 0; - } else { - return content.length(); - } + return contentLength; } public Exception getException() { @@ -97,12 +94,6 @@ public Releasable takeBreakerReleaseControl() { } public StreamInput openOrGetStreamInput() throws IOException { - assert isPing == false && content != null; - assert (boolean) CLOSED.getAcquire(this) == false; - if (streamInput == null) { - streamInput = content.streamInput(); - streamInput.setTransportVersion(header.getVersion()); - } return streamInput; } @@ -117,7 +108,7 @@ public void close() { return; } try { - IOUtils.close(streamInput, content, breakerRelease); + IOUtils.close(streamInput, breakerRelease); } catch (Exception e) { assert false : e; throw new ElasticsearchException(e); diff --git a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java index a683c2332c451..301e83d08cacd 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java @@ -16,7 +16,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.InputStreamStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -190,7 +189,7 @@ public TestResponse read(StreamInput in) throws IOException { TransportStatus.setRequest((byte) 0), TransportVersion.current() ); - InboundMessage requestMessage = new 
InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + InboundMessage requestMessage = new InboundMessage(requestHeader, requestContent.streamInput(), requestContent.length(), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -210,7 +209,12 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(TcpHeader.HEADER_SIZE, fullResponseBytes.length() - TcpHeader.HEADER_SIZE); Header responseHeader = new Header(fullRequestBytes.length() - 6, requestId, responseStatus, TransportVersion.current()); - InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); + InboundMessage responseMessage = new InboundMessage( + responseHeader, + responseContent.streamInput(), + responseContent.length(), + () -> {} + ); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -298,7 +302,8 @@ public void testLogsSlowInboundProcessing() throws Exception { } final InboundMessage requestMessage = new InboundMessage( requestHeader, - ReleasableBytesReference.wrap(byteData.bytes()), + byteData.bytes().streamInput(), + byteData.size(), () -> safeSleep(TimeValue.timeValueSeconds(1)) ); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; @@ -327,14 +332,14 @@ public void onResponseReceived(long requestId, Transport.ResponseContext context safeSleep(TimeValue.timeValueSeconds(1)); } }); - handler.inboundMessage(channel, new InboundMessage(responseHeader, ReleasableBytesReference.empty(), () -> {})); + handler.inboundMessage(channel, new InboundMessage(responseHeader, BytesArray.EMPTY.streamInput(), 0, () -> {})); mockLog.assertAllExpectationsMatched(); } } private static InboundMessage 
unreadableInboundHandshake(TransportVersion remoteVersion, Header requestHeader) { - return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { + return new InboundMessage(requestHeader, BytesArray.EMPTY.streamInput(), 0, () -> {}) { @Override public StreamInput openOrGetStreamInput() { final StreamInput streamInput = new InputStreamStreamInput(new InputStream() { From 50fe2b7bc56bcf085755fb6fad725eeee4a61ce9 Mon Sep 17 00:00:00 2001 From: Armin Date: Sat, 26 Apr 2025 23:38:04 +0200 Subject: [PATCH 2/2] Release transport messages incrementally while reading them There are still a few small steps left to clean this up, but even in this form this solution can reduce the heap used for handling large bulk shard requests on non-coordinating data nodes by up to half (this is just one example; there are a couple of other spots where this saves a lot of memory). Also, this could easily be extended to be a little smarter around compression, allowing for potential order-of-magnitude savings around indexing if we lazily deserialize individual docs or play similar tricks.
--- .../bytes/ReleasableBytesReference.java | 28 +++----------- .../common/io/stream/FilterStreamInput.java | 5 +++ .../common/io/stream/StreamInput.java | 6 ++- .../transport/TransportLogger.java | 37 +++++++------------ 4 files changed, 29 insertions(+), 47 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java b/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java index 86d6908d7e58d..92fee6cb3d186 100644 --- a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java +++ b/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java @@ -259,30 +259,14 @@ public void close() { } } - @Override - public int read(byte[] b, int bOffset, int len) throws IOException { - int res = super.read(b, bOffset, len); + public void tryDiscard() { if (markEnd == 0) { - tryDiscard(); - } - return res; - } - - private void tryDiscard() { - if (bytesReference instanceof CompositeBytesReference c) { - maybeDiscardReadBytes(c.components()); - } else if (available() == 0) { - close(); - } - } - - @Override - public int read() throws IOException { - int res = super.read(); - if (res == -1 && markEnd == 0) { - close(); + if (bytesReference instanceof CompositeBytesReference c) { + maybeDiscardReadBytes(c.components()); + } else if (available() == 0) { + close(); + } } - return res; } private int markEnd = 0; diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java index 34de4faa69f62..3a66cb7e5aaa3 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java @@ -108,6 +108,11 @@ public int read(byte[] b, int off, int len) throws IOException { return delegate.read(b, off, len); } + @Override + public void tryDiscard() { + 
delegate.tryDiscard(); + } + @Override public void close() throws IOException { delegate.close(); diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java index 8b54c7a78907c..58ada13ccc518 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java @@ -109,13 +109,17 @@ public void setTransportVersion(TransportVersion version) { @Override public abstract int read(byte[] b, int off, int len) throws IOException; + public void tryDiscard() {} + /** * Reads a bytes reference from this stream, copying any bytes read to a new {@code byte[]}. Use {@link #readReleasableBytesReference()} * when reading large bytes references where possible top avoid needless allocations and copying. */ public BytesReference readBytesReference() throws IOException { int length = readArraySize(); - return readBytesReference(length); + var res = readBytesReference(length); + tryDiscard(); + return res; } /** diff --git a/server/src/main/java/org/elasticsearch/transport/TransportLogger.java b/server/src/main/java/org/elasticsearch/transport/TransportLogger.java index 6cad97acab7ba..b3cb5e47bd499 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportLogger.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportLogger.java @@ -118,35 +118,24 @@ private static String format(TcpChannel channel, InboundMessage message, String if (message.isPing()) { sb.append(" [ping]").append(' ').append(event).append(": ").append(6).append('B'); } else { - boolean success = false; Header header = message.getHeader(); int networkMessageSize = header.getNetworkMessageSize(); int messageLengthWithHeader = HEADER_SIZE + networkMessageSize; - StreamInput streamInput = message.openOrGetStreamInput(); - try { - final long requestId = header.getRequestId(); - final boolean isRequest = 
header.isRequest(); - final String type = isRequest ? "request" : "response"; - final String version = header.getVersion().toString(); - sb.append(" [length: ").append(messageLengthWithHeader); - sb.append(", request id: ").append(requestId); - sb.append(", type: ").append(type); - sb.append(", version: ").append(version); + final long requestId = header.getRequestId(); + final boolean isRequest = header.isRequest(); + final String type = isRequest ? "request" : "response"; + final String version = header.getVersion().toString(); + sb.append(" [length: ").append(messageLengthWithHeader); + sb.append(", request id: ").append(requestId); + sb.append(", type: ").append(type); + sb.append(", version: ").append(version); - // TODO: Maybe Fix for BWC - if (header.needsToReadVariableHeader() == false && isRequest) { - sb.append(", action: ").append(header.getActionName()); - } - sb.append(']'); - sb.append(' ').append(event).append(": ").append(messageLengthWithHeader).append('B'); - success = true; - } finally { - if (success) { - IOUtils.close(streamInput); - } else { - IOUtils.closeWhileHandlingException(streamInput); - } + // TODO: Maybe Fix for BWC + if (header.needsToReadVariableHeader() == false && isRequest) { + sb.append(", action: ").append(header.getActionName()); } + sb.append(']'); + sb.append(' ').append(event).append(": ").append(messageLengthWithHeader).append('B'); } return sb.toString(); }