Skip to content

Commit 8be906d

Browse files
SuficioLucioFranco
andauthored
Handle stream error correctly (#2199)
Co-authored-by: Lucio Franco <luciofranco14@gmail.com>
1 parent 7b2984c commit 8be906d

File tree

2 files changed

+94
-6
lines changed

2 files changed

+94
-6
lines changed

tests/integration_tests/tests/status.rs

+90
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ use integration_tests::pb::{
66
test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output,
77
OutputStream,
88
};
9+
use integration_tests::BoxFuture;
910
use std::error::Error;
11+
use std::task::{Context, Poll};
1012
use std::time::Duration;
1113
use tokio::{net::TcpListener, sync::oneshot};
14+
use tonic::body::Body;
1215
use tonic::metadata::{MetadataMap, MetadataValue};
1316
use tonic::{
1417
transport::{server::TcpIncoming, Endpoint, Server},
@@ -209,6 +212,93 @@ async fn status_from_server_stream_with_source() {
209212
source.downcast_ref::<tonic::transport::Error>().unwrap();
210213
}
211214

215+
#[tokio::test]
216+
async fn status_from_server_stream_with_inferred_status() {
217+
integration_tests::trace_init();
218+
219+
struct Svc;
220+
221+
#[tonic::async_trait]
222+
impl test_stream_server::TestStream for Svc {
223+
type StreamCallStream = Stream<OutputStream>;
224+
225+
async fn stream_call(
226+
&self,
227+
_: Request<InputStream>,
228+
) -> Result<Response<Self::StreamCallStream>, Status> {
229+
let s = tokio_stream::once(Ok(OutputStream {}));
230+
Ok(Response::new(Box::pin(s) as Self::StreamCallStream))
231+
}
232+
}
233+
234+
#[derive(Clone)]
235+
struct TestLayer;
236+
237+
impl<S> tower::Layer<S> for TestLayer {
238+
type Service = TestService;
239+
240+
fn layer(&self, _: S) -> Self::Service {
241+
TestService
242+
}
243+
}
244+
245+
#[derive(Clone)]
246+
struct TestService;
247+
248+
impl tower::Service<http::Request<Body>> for TestService {
249+
type Response = http::Response<Body>;
250+
type Error = std::convert::Infallible;
251+
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
252+
253+
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
254+
Poll::Ready(Ok(()))
255+
}
256+
257+
fn call(&mut self, _: http::Request<Body>) -> Self::Future {
258+
Box::pin(async {
259+
Ok(http::Response::builder()
260+
.status(http::StatusCode::BAD_GATEWAY)
261+
.body(Body::empty())
262+
.unwrap())
263+
})
264+
}
265+
}
266+
267+
let svc = test_stream_server::TestStreamServer::new(Svc);
268+
269+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
270+
let addr = listener.local_addr().unwrap();
271+
let incoming: TcpIncoming = TcpIncoming::from(listener).with_nodelay(Some(true));
272+
273+
tokio::spawn(async move {
274+
Server::builder()
275+
.layer(TestLayer)
276+
.add_service(svc)
277+
.serve_with_incoming(incoming)
278+
.await
279+
.unwrap();
280+
});
281+
282+
tokio::time::sleep(Duration::from_millis(100)).await;
283+
284+
let mut client = test_stream_client::TestStreamClient::connect(format!("http://{addr}"))
285+
.await
286+
.unwrap();
287+
288+
let mut stream = client
289+
.stream_call(InputStream {})
290+
.await
291+
.unwrap()
292+
.into_inner();
293+
294+
assert_eq!(
295+
stream.message().await.unwrap_err().code(),
296+
Code::Unavailable
297+
);
298+
299+
assert_eq!(stream.message().await.unwrap(), None);
300+
}
301+
212302
#[tokio::test]
213303
async fn message_and_then_status_from_server_stream() {
214304
integration_tests::trace_init();

tonic/src/codec/decode.rs

+4-6
Original file line numberDiff line numberDiff line change
@@ -400,14 +400,12 @@ impl<T> Stream for Streaming<T> {
400400
}
401401

402402
if ready!(self.inner.poll_frame(cx))?.is_none() {
403-
break;
403+
match self.inner.response() {
404+
Ok(()) => return Poll::Ready(None),
405+
Err(err) => self.inner.state = State::Error(Some(err)),
406+
}
404407
}
405408
}
406-
407-
Poll::Ready(match self.inner.response() {
408-
Ok(()) => None,
409-
Err(err) => Some(Err(err)),
410-
})
411409
}
412410
}
413411

0 commit comments

Comments
 (0)