Skip to content

Commit 2181aa8

Browse files
author
Thomas Scholtes
committed
Fix panics and unsafe code
This change fixes panics (routerify#6) and unsafe code (routerify#5). This comes at the cost of an additional copy of the data send through the pipe and having a buffer in the state. Moreover all unsafe code is removed and the need for a custom `Drop` implementation which makes the code overall easier. We also add tests.
1 parent 74c0e25 commit 2181aa8

File tree

4 files changed

+127
-113
lines changed

4 files changed

+127
-113
lines changed

src/lib.rs

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,8 @@ pub fn pipe() -> (PipeWriter, PipeReader) {
4242
let shared_state = Arc::new(Mutex::new(State {
4343
reader_waker: None,
4444
writer_waker: None,
45-
data: None,
46-
done_reading: false,
47-
read: 0,
48-
done_cycle: true,
4945
closed: false,
46+
buffer: Vec::new(),
5047
}));
5148

5249
let w = PipeWriter {
@@ -59,3 +56,67 @@ pub fn pipe() -> (PipeWriter, PipeReader) {
5956

6057
(w, r)
6158
}
59+
60+
#[cfg(test)]
61+
mod test {
62+
use super::pipe;
63+
use std::io;
64+
use tokio::prelude::*;
65+
66+
#[tokio::test]
67+
async fn read_write() {
68+
let (mut writer, mut reader) = pipe();
69+
let data = b"hello world";
70+
71+
let write_handle = tokio::spawn(async move {
72+
writer.write_all(data).await.unwrap();
73+
});
74+
75+
let mut read_buf = Vec::new();
76+
reader.read_to_end(&mut read_buf).await.unwrap();
77+
write_handle.await.unwrap();
78+
79+
assert_eq!(&read_buf, data);
80+
}
81+
82+
#[tokio::test]
83+
async fn eof_when_writer_is_shutdown() {
84+
let (mut writer, mut reader) = pipe();
85+
writer.shutdown().await.unwrap();
86+
let mut buf = [0u8; 8];
87+
let bytes_read = reader.read(&mut buf).await.unwrap();
88+
assert_eq!(bytes_read, 0);
89+
}
90+
91+
#[tokio::test]
92+
async fn broken_pipe_when_reader_is_dropped() {
93+
let (mut writer, reader) = pipe();
94+
drop(reader);
95+
let io_error = writer.write_all(&[0u8; 8]).await.unwrap_err();
96+
assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe);
97+
}
98+
99+
#[tokio::test]
100+
async fn eof_when_writer_is_dropped() {
101+
let (writer, mut reader) = pipe();
102+
drop(writer);
103+
let mut buf = [0u8; 8];
104+
let bytes_read = reader.read(&mut buf).await.unwrap();
105+
assert_eq!(bytes_read, 0);
106+
}
107+
108+
#[tokio::test]
109+
async fn drop_read_exact() {
110+
let (mut writer, mut reader) = pipe();
111+
const BUF_SIZE: usize = 8;
112+
113+
let write_handle = tokio::spawn(async move {
114+
writer.write_all(&mut [0u8; BUF_SIZE]).await.unwrap();
115+
});
116+
117+
let mut buf = [0u8; BUF_SIZE];
118+
reader.read_exact(&mut buf).await.unwrap();
119+
drop(reader);
120+
write_handle.await.unwrap();
121+
}
122+
}

src/reader.rs

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
use crate::state::{Data, State};
1+
use crate::state::State;
22
use std::io;
33
use std::pin::Pin;
4-
use std::ptr;
54
use std::sync::{Arc, Mutex};
65
use std::task::{Context, Poll};
76

@@ -53,7 +52,7 @@ impl PipeReader {
5352
}
5453
};
5554

56-
Ok(state.done_cycle)
55+
Ok(state.buffer.is_empty())
5756
}
5857

5958
fn wake_writer_half(&self, state: &State) {
@@ -62,22 +61,13 @@ impl PipeReader {
6261
}
6362
}
6463

65-
fn copy_data_into_buffer(&self, data: &Data, buf: &mut [u8]) -> usize {
66-
let len = data.len.min(buf.len());
67-
unsafe {
68-
ptr::copy_nonoverlapping(data.ptr, buf.as_mut_ptr(), len);
69-
}
70-
len
71-
}
72-
7364
fn poll_read(
7465
self: Pin<&mut Self>,
7566
cx: &mut Context,
7667
buf: &mut [u8],
7768
) -> Poll<io::Result<usize>> {
78-
let mut state;
79-
match self.state.lock() {
80-
Ok(s) => state = s,
69+
let mut state = match self.state.lock() {
70+
Ok(s) => s,
8171
Err(err) => {
8272
return Poll::Ready(Err(io::Error::new(
8373
io::ErrorKind::Other,
@@ -88,43 +78,24 @@ impl PipeReader {
8878
),
8979
)))
9080
}
91-
}
92-
93-
if state.closed {
94-
return Poll::Ready(Ok(0));
95-
}
96-
97-
return if state.done_cycle {
98-
state.reader_waker = Some(cx.waker().clone());
99-
Poll::Pending
100-
} else {
101-
if let Some(ref data) = state.data {
102-
let copied_bytes_len = self.copy_data_into_buffer(data, buf);
103-
104-
state.data = None;
105-
state.read = copied_bytes_len;
106-
state.done_reading = true;
107-
state.reader_waker = None;
108-
109-
self.wake_writer_half(&*state);
81+
};
11082

111-
Poll::Ready(Ok(copied_bytes_len))
83+
if state.buffer.is_empty() {
84+
if state.closed || Arc::strong_count(&self.state) == 1 {
85+
Poll::Ready(Ok(0))
11286
} else {
87+
self.wake_writer_half(&*state);
11388
state.reader_waker = Some(cx.waker().clone());
11489
Poll::Pending
11590
}
116-
};
117-
}
118-
}
91+
} else {
92+
self.wake_writer_half(&*state);
93+
let size_to_read = state.buffer.len().min(buf.len());
94+
let (to_read, rest) = state.buffer.split_at(size_to_read);
95+
buf[..size_to_read].copy_from_slice(to_read);
96+
state.buffer = rest.to_vec();
11997

120-
impl Drop for PipeReader {
121-
fn drop(&mut self) {
122-
if let Err(err) = self.close() {
123-
log::warn!(
124-
"{}: PipeReader: Failed to close the channel on drop: {}",
125-
env!("CARGO_PKG_NAME"),
126-
err
127-
);
98+
Poll::Ready(Ok(size_to_read))
12899
}
129100
}
130101
}

src/state.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
use std::task::Waker;
22

3+
pub const BUFFER_SIZE: usize = 1024;
4+
35
pub(crate) struct State {
46
pub(crate) reader_waker: Option<Waker>,
57
pub(crate) writer_waker: Option<Waker>,
6-
pub(crate) data: Option<Data>,
7-
pub(crate) done_reading: bool,
8-
pub(crate) read: usize,
9-
pub(crate) done_cycle: bool,
108
pub(crate) closed: bool,
9+
pub(crate) buffer: Vec<u8>,
1110
}
12-
13-
pub(crate) struct Data {
14-
pub(crate) ptr: *const u8,
15-
pub(crate) len: usize,
16-
}
17-
18-
unsafe impl Send for Data {}

src/writer.rs

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
use crate::state::Data;
2-
use crate::state::State;
1+
use crate::state::{State, BUFFER_SIZE};
32
use std::io;
43
use std::pin::Pin;
54
use std::sync::{Arc, Mutex};
@@ -53,7 +52,7 @@ impl PipeWriter {
5352
}
5453
};
5554

56-
Ok(state.done_cycle)
55+
Ok(state.buffer.is_empty())
5756
}
5857

5958
fn wake_reader_half(&self, state: &State) {
@@ -63,9 +62,18 @@ impl PipeWriter {
6362
}
6463

6564
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
66-
let mut state;
67-
match self.state.lock() {
68-
Ok(s) => state = s,
65+
if Arc::strong_count(&self.state) == 1 {
66+
return Poll::Ready(Err(io::Error::new(
67+
io::ErrorKind::BrokenPipe,
68+
format!(
69+
"{}: PipeWriter: The channel is closed",
70+
env!("CARGO_PKG_NAME")
71+
),
72+
)));
73+
}
74+
75+
let mut state = match self.state.lock() {
76+
Ok(s) => s,
6977
Err(err) => {
7078
return Poll::Ready(Err(io::Error::new(
7179
io::ErrorKind::Other,
@@ -76,49 +84,43 @@ impl PipeWriter {
7684
),
7785
)))
7886
}
79-
}
87+
};
8088

81-
if state.closed {
82-
return Poll::Ready(Err(io::Error::new(
83-
io::ErrorKind::BrokenPipe,
84-
format!(
85-
"{}: PipeWriter: The channel is closed",
86-
env!("CARGO_PKG_NAME")
87-
),
88-
)));
89-
}
89+
self.wake_reader_half(&*state);
9090

91-
return if state.done_cycle {
92-
state.data = Some(Data {
93-
ptr: buf.as_ptr(),
94-
len: buf.len(),
95-
});
96-
state.done_cycle = false;
91+
let remaining = BUFFER_SIZE - state.buffer.len();
92+
if remaining == 0 {
9793
state.writer_waker = Some(cx.waker().clone());
98-
99-
self.wake_reader_half(&*state);
100-
10194
Poll::Pending
10295
} else {
103-
if state.done_reading {
104-
let read_bytes_len = state.read;
105-
106-
state.done_cycle = true;
107-
state.read = 0;
108-
state.writer_waker = None;
109-
state.data = None;
110-
state.done_reading = false;
111-
112-
Poll::Ready(Ok(read_bytes_len))
113-
} else {
114-
state.writer_waker = Some(cx.waker().clone());
115-
Poll::Pending
96+
let bytes_to_write = remaining.min(buf.len());
97+
state.buffer.extend_from_slice(&buf[..bytes_to_write]);
98+
Poll::Ready(Ok(bytes_to_write))
99+
}
100+
}
101+
102+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
103+
let mut state = match self.state.lock() {
104+
Ok(s) => s,
105+
Err(err) => {
106+
return Poll::Ready(Err(io::Error::new(
107+
io::ErrorKind::Other,
108+
format!(
109+
"{}: PipeWriter: Failed to lock the channel state: {}",
110+
env!("CARGO_PKG_NAME"),
111+
err
112+
),
113+
)))
116114
}
117115
};
118-
}
119116

120-
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
121-
Poll::Ready(Ok(()))
117+
if state.buffer.is_empty() {
118+
Poll::Ready(Ok(()))
119+
} else {
120+
state.writer_waker = Some(cx.waker().clone());
121+
self.wake_reader_half(&*state);
122+
Poll::Pending
123+
}
122124
}
123125

124126
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
@@ -136,18 +138,6 @@ impl PipeWriter {
136138
}
137139
}
138140

139-
impl Drop for PipeWriter {
140-
fn drop(&mut self) {
141-
if let Err(err) = self.close() {
142-
log::warn!(
143-
"{}: PipeWriter: Failed to close the channel on drop: {}",
144-
env!("CARGO_PKG_NAME"),
145-
err
146-
);
147-
}
148-
}
149-
}
150-
151141
#[cfg(feature = "tokio")]
152142
impl tokio::io::AsyncWrite for PipeWriter {
153143
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {

0 commit comments

Comments
 (0)