14
14
import io .netty .channel .ChannelPromise ;
15
15
import io .netty .channel .embedded .EmbeddedChannel ;
16
16
17
+ import org .elasticsearch .common .bytes .BytesArray ;
18
+ import org .elasticsearch .common .bytes .BytesReference ;
19
+ import org .elasticsearch .common .bytes .CompositeBytesReference ;
17
20
import org .elasticsearch .common .settings .Settings ;
18
21
import org .elasticsearch .common .util .concurrent .ThreadContext ;
19
22
import org .elasticsearch .test .ESTestCase ;
28
31
import static org .hamcrest .Matchers .instanceOf ;
29
32
import static org .hamcrest .Matchers .lessThan ;
30
33
import static org .hamcrest .Matchers .lessThanOrEqualTo ;
34
+ import static org .hamcrest .Matchers .oneOf ;
31
35
32
36
public class Netty4WriteThrottlingHandlerTests extends ESTestCase {
33
37
@@ -56,42 +60,76 @@ public void testThrottlesLargeMessage() throws ExecutionException, InterruptedEx
56
60
assertThat (writeableBytes , lessThan (Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE ));
57
61
final int fullSizeChunks = randomIntBetween (2 , 10 );
58
62
final int extraChunkSize = randomIntBetween (0 , 10 );
59
- final ByteBuf message = Unpooled . wrappedBuffer (
60
- randomByteArrayOfLength ( Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE * fullSizeChunks + extraChunkSize )
63
+ final byte [] messageBytes = randomByteArrayOfLength (
64
+ Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE * fullSizeChunks + extraChunkSize
61
65
);
66
+ final Object message = wrapAsNettyOrEsBuffer (messageBytes );
62
67
final ChannelPromise promise = embeddedChannel .newPromise ();
63
68
transportGroup .getLowLevelGroup ().submit (() -> embeddedChannel .write (message , promise )).get ();
64
69
assertThat (seen , hasSize (1 ));
65
- assertEquals ( message . slice ( 0 , Netty4WriteThrottlingHandler . MAX_BYTES_PER_WRITE ), seen . get ( 0 ) );
70
+ assertSliceEquals ( seen . get ( 0 ), message , 0 , Netty4WriteThrottlingHandler . MAX_BYTES_PER_WRITE );
66
71
assertFalse (promise .isDone ());
67
72
transportGroup .getLowLevelGroup ().submit (embeddedChannel ::flush ).get ();
68
73
assertTrue (promise .isDone ());
69
74
assertThat (seen , hasSize (fullSizeChunks + (extraChunkSize == 0 ? 0 : 1 )));
70
75
assertTrue (capturingHandler .didWriteAfterThrottled );
71
76
if (extraChunkSize != 0 ) {
72
- assertEquals (
73
- message .slice (Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE * fullSizeChunks , extraChunkSize ),
74
- seen .get (seen .size () - 1 )
77
+ assertSliceEquals (
78
+ seen .get (seen .size () - 1 ),
79
+ message ,
80
+ Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE * fullSizeChunks ,
81
+ extraChunkSize
75
82
);
76
83
}
77
84
}
78
85
79
- public void testPassesSmallMessageDirectly () throws ExecutionException , InterruptedException {
86
+ public void testThrottleLargeCompositeMessage () throws ExecutionException , InterruptedException {
80
87
final List <ByteBuf > seen = new CopyOnWriteArrayList <>();
81
88
final CapturingHandler capturingHandler = new CapturingHandler (seen );
82
89
final EmbeddedChannel embeddedChannel = new EmbeddedChannel (
83
90
capturingHandler ,
84
91
new Netty4WriteThrottlingHandler (new ThreadContext (Settings .EMPTY ))
85
92
);
93
+ // we assume that the channel outbound buffer is smaller than Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE
86
94
final int writeableBytes = Math .toIntExact (embeddedChannel .bytesBeforeUnwritable ());
87
95
assertThat (writeableBytes , lessThan (Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE ));
88
- final ByteBuf message = Unpooled .wrappedBuffer (
89
- randomByteArrayOfLength (randomIntBetween (0 , Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE ))
96
+ final int fullSizeChunks = randomIntBetween (2 , 10 );
97
+ final int extraChunkSize = randomIntBetween (0 , 10 );
98
+ final byte [] messageBytes = randomByteArrayOfLength (
99
+ Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE * fullSizeChunks + extraChunkSize
100
+ );
101
+ int splitOffset = randomIntBetween (0 , messageBytes .length );
102
+ final BytesReference message = CompositeBytesReference .of (
103
+ new BytesArray (messageBytes , 0 , splitOffset ),
104
+ new BytesArray (messageBytes , splitOffset , messageBytes .length - splitOffset )
105
+ );
106
+ final ChannelPromise promise = embeddedChannel .newPromise ();
107
+ transportGroup .getLowLevelGroup ().submit (() -> embeddedChannel .write (message , promise )).get ();
108
+ assertThat (seen , hasSize (oneOf (1 , 2 )));
109
+ assertSliceEquals (seen .get (0 ), message , 0 , seen .get (0 ).readableBytes ());
110
+ assertFalse (promise .isDone ());
111
+ transportGroup .getLowLevelGroup ().submit (embeddedChannel ::flush ).get ();
112
+ assertTrue (promise .isDone ());
113
+ assertThat (seen , hasSize (oneOf (fullSizeChunks , fullSizeChunks + 1 )));
114
+ assertTrue (capturingHandler .didWriteAfterThrottled );
115
+ assertBufferEquals (Unpooled .compositeBuffer ().addComponents (true , seen ), message );
116
+ }
117
+
118
+ public void testPassesSmallMessageDirectly () throws ExecutionException , InterruptedException {
119
+ final List <ByteBuf > seen = new CopyOnWriteArrayList <>();
120
+ final CapturingHandler capturingHandler = new CapturingHandler (seen );
121
+ final EmbeddedChannel embeddedChannel = new EmbeddedChannel (
122
+ capturingHandler ,
123
+ new Netty4WriteThrottlingHandler (new ThreadContext (Settings .EMPTY ))
90
124
);
125
+ final int writeableBytes = Math .toIntExact (embeddedChannel .bytesBeforeUnwritable ());
126
+ assertThat (writeableBytes , lessThan (Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE ));
127
+ final byte [] messageBytes = randomByteArrayOfLength (randomIntBetween (0 , Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE ));
128
+ final Object message = wrapAsNettyOrEsBuffer (messageBytes );
91
129
final ChannelPromise promise = embeddedChannel .newPromise ();
92
130
transportGroup .getLowLevelGroup ().submit (() -> embeddedChannel .write (message , promise )).get ();
93
131
assertThat (seen , hasSize (1 )); // first message should be passed through straight away
94
- assertSame ( message , seen .get (0 ));
132
+ assertBufferEquals ( seen .get (0 ), message );
95
133
assertFalse (promise .isDone ());
96
134
transportGroup .getLowLevelGroup ().submit (embeddedChannel ::flush ).get ();
97
135
assertTrue (promise .isDone ());
@@ -107,13 +145,14 @@ public void testThrottlesOnUnwritable() throws ExecutionException, InterruptedEx
107
145
);
108
146
final int writeableBytes = Math .toIntExact (embeddedChannel .bytesBeforeUnwritable ());
109
147
assertThat (writeableBytes , lessThan (Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE ));
110
- final ByteBuf message = Unpooled .wrappedBuffer (randomByteArrayOfLength (writeableBytes + randomIntBetween (0 , 10 )));
148
+ final byte [] messageBytes = randomByteArrayOfLength (writeableBytes + randomIntBetween (0 , 10 ));
149
+ final Object message = wrapAsNettyOrEsBuffer (messageBytes );
111
150
final ChannelPromise promise = embeddedChannel .newPromise ();
112
151
transportGroup .getLowLevelGroup ().submit (() -> embeddedChannel .write (message , promise )).get ();
113
152
assertThat (seen , hasSize (1 )); // first message should be passed through straight away
114
- assertSame ( message , seen .get (0 ));
153
+ assertBufferEquals ( seen .get (0 ), message );
115
154
assertFalse (promise .isDone ());
116
- final ByteBuf messageToQueue = Unpooled . wrappedBuffer (
155
+ final Object messageToQueue = wrapAsNettyOrEsBuffer (
117
156
randomByteArrayOfLength (randomIntBetween (0 , Netty4WriteThrottlingHandler .MAX_BYTES_PER_WRITE ))
118
157
);
119
158
final ChannelPromise promiseForQueued = embeddedChannel .newPromise ();
@@ -126,6 +165,31 @@ public void testThrottlesOnUnwritable() throws ExecutionException, InterruptedEx
126
165
assertTrue (promiseForQueued .isDone ());
127
166
}
128
167
168
+ private static void assertBufferEquals (ByteBuf expected , Object message ) {
169
+ if (message instanceof ByteBuf buf ) {
170
+ assertSame (expected , buf );
171
+ } else {
172
+ assertEquals (expected , Netty4Utils .toByteBuf (asInstanceOf (BytesReference .class , message )));
173
+ }
174
+ }
175
+
176
+ private static void assertSliceEquals (ByteBuf expected , Object message , int index , int length ) {
177
+ assertEquals (
178
+ (message instanceof ByteBuf buf ? buf : Netty4Utils .toByteBuf (asInstanceOf (BytesReference .class , message ))).slice (
179
+ index ,
180
+ length
181
+ ),
182
+ expected
183
+ );
184
+ }
185
+
186
+ private static Object wrapAsNettyOrEsBuffer (byte [] messageBytes ) {
187
+ if (randomBoolean ()) {
188
+ return Unpooled .wrappedBuffer (messageBytes );
189
+ }
190
+ return new BytesArray (messageBytes );
191
+ }
192
+
129
193
private static class CapturingHandler extends ChannelOutboundHandlerAdapter {
130
194
private final List <ByteBuf > seen ;
131
195
0 commit comments