|
| 1 | +use std::future::Future; |
1 | 2 | use std::io;
|
2 | 3 | use std::path::Path;
|
3 | 4 | use std::pin::Pin;
|
4 | 5 | use std::task::{ready, Context, Poll};
|
5 |
| -use std::{ |
6 |
| - future::Future, |
7 |
| - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, |
8 |
| -}; |
9 | 6 |
|
| 7 | +pub use buffered::{BufferedSocket, WriteBuffer}; |
10 | 8 | use bytes::BufMut;
|
11 | 9 | use cfg_if::cfg_if;
|
12 | 10 |
|
13 |
| -pub use buffered::{BufferedSocket, WriteBuffer}; |
14 |
| - |
15 |
| -use crate::{io::ReadBuf, rt::spawn_blocking}; |
| 11 | +use crate::io::ReadBuf; |
16 | 12 |
|
17 | 13 | mod buffered;
|
18 | 14 |
|
@@ -146,10 +142,7 @@ where
|
146 | 142 | pub trait WithSocket {
|
147 | 143 | type Output;
|
148 | 144 |
|
149 |
| - fn with_socket<S: Socket>( |
150 |
| - self, |
151 |
| - socket: S, |
152 |
| - ) -> impl std::future::Future<Output = Self::Output> + Send; |
| 145 | + fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send; |
153 | 146 | }
|
154 | 147 |
|
155 | 148 | pub struct SocketIntoBox;
|
@@ -193,98 +186,67 @@ pub async fn connect_tcp<Ws: WithSocket>(
|
193 | 186 | port: u16,
|
194 | 187 | with_socket: Ws,
|
195 | 188 | ) -> crate::Result<Ws::Output> {
|
196 |
| - // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those. |
197 |
| - let host = host.trim_matches(&['[', ']'][..]); |
198 |
| - |
199 |
| - let addresses = if let Ok(addr) = host.parse::<Ipv4Addr>() { |
200 |
| - let addr = SocketAddrV4::new(addr, port); |
201 |
| - vec![SocketAddr::V4(addr)].into_iter() |
202 |
| - } else if let Ok(addr) = host.parse::<Ipv6Addr>() { |
203 |
| - let addr = SocketAddrV6::new(addr, port, 0, 0); |
204 |
| - vec![SocketAddr::V6(addr)].into_iter() |
205 |
| - } else { |
206 |
| - let host = host.to_string(); |
207 |
| - spawn_blocking(move || { |
208 |
| - let addr = (host.as_str(), port); |
209 |
| - ToSocketAddrs::to_socket_addrs(&addr) |
210 |
| - }) |
211 |
| - .await? |
212 |
| - }; |
213 |
| - |
214 |
| - let mut last_err = None; |
215 |
| - |
216 |
| - // Loop through all the Socket Addresses that the hostname resolves to |
217 |
| - for socket_addr in addresses { |
218 |
| - match connect_tcp_address(socket_addr).await { |
219 |
| - Ok(stream) => return Ok(with_socket.with_socket(stream).await), |
220 |
| - Err(e) => last_err = Some(e), |
221 |
| - } |
| 189 | + #[cfg(feature = "_rt-tokio")] |
| 190 | + if crate::rt::rt_tokio::available() { |
| 191 | + return Ok(with_socket |
| 192 | + .with_socket(tokio::net::TcpStream::connect((host, port)).await?) |
| 193 | + .await); |
222 | 194 | }
|
223 | 195 |
|
224 |
| - // If we reach this point, it means we failed to connect to any of the addresses. |
225 |
| - // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. |
226 |
| - Err(match last_err { |
227 |
| - Some(err) => err, |
228 |
| - None => io::Error::new( |
229 |
| - io::ErrorKind::AddrNotAvailable, |
230 |
| - "Hostname did not resolve to any addresses", |
231 |
| - ) |
232 |
| - .into(), |
233 |
| - }) |
234 |
| -} |
235 |
| - |
236 |
| -async fn connect_tcp_address(socket_addr: SocketAddr) -> crate::Result<impl Socket> { |
237 | 196 | cfg_if! {
|
238 |
| - if #[cfg(feature = "_rt-tokio")] { |
239 |
| - if crate::rt::rt_tokio::available() { |
240 |
| - use tokio::net::TcpStream; |
241 |
| - |
242 |
| - let stream = TcpStream::connect(socket_addr).await?; |
243 |
| - stream.set_nodelay(true)?; |
244 |
| - |
245 |
| - Ok(stream) |
246 |
| - } else { |
247 |
| - crate::rt::missing_rt(socket_addr) |
248 |
| - } |
249 |
| - } else if #[cfg(feature = "_rt-async-io")] { |
250 |
| - use async_io::Async; |
251 |
| - use std::net::TcpStream; |
252 |
| - |
253 |
| - let stream = Async::<TcpStream>::connect(socket_addr).await?; |
254 |
| - stream.get_ref().set_nodelay(true)?; |
255 |
| - |
256 |
| - Ok(stream) |
| 197 | + if #[cfg(feature = "_rt-async-io")] { |
| 198 | + Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await) |
257 | 199 | } else {
|
258 |
| - crate::rt::missing_rt(socket_addr); |
259 |
| - #[allow(unreachable_code)] |
260 |
| - Ok(()) |
| 200 | + crate::rt::missing_rt((host, port, with_socket)) |
261 | 201 | }
|
262 | 202 | }
|
263 | 203 | }
|
264 | 204 |
|
265 |
| -// Work around `impl Socket`` and 'unability to specify test build cargo feature'. |
266 |
| -// `connect_tcp_address` compilation would fail without this impl with |
267 |
| -// 'cannot infer return type' error. |
268 |
| -impl Socket for () { |
269 |
| - fn try_read(&mut self, _: &mut dyn ReadBuf) -> io::Result<usize> { |
270 |
| - unreachable!() |
271 |
| - } |
| 205 | +/// Open a TCP socket to `host` and `port`. |
| 206 | +/// |
| 207 | +/// If `host` is a hostname, attempt to connect to each address it resolves to. |
| 208 | +/// |
| 209 | +/// This implements the same behavior as [`tokio::net::TcpStream::connect()`]. |
| 210 | +#[cfg(feature = "_rt-async-io")] |
| 211 | +async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> { |
| 212 | + use async_io::Async; |
| 213 | + use std::net::{IpAddr, TcpStream, ToSocketAddrs}; |
272 | 214 |
|
273 |
| - fn try_write(&mut self, _: &[u8]) -> io::Result<usize> { |
274 |
| - unreachable!() |
275 |
| - } |
| 215 | + // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those. |
| 216 | + let host = host.trim_matches(&['[', ']'][..]); |
276 | 217 |
|
277 |
| - fn poll_read_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> { |
278 |
| - unreachable!() |
| 218 | + if let Ok(addr) = host.parse::<IpAddr>() { |
| 219 | + return Ok(Async::<TcpStream>::connect((addr, port)).await?); |
279 | 220 | }
|
280 | 221 |
|
281 |
| - fn poll_write_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> { |
282 |
| - unreachable!() |
283 |
| - } |
| 222 | + let host = host.to_string(); |
| 223 | + |
| 224 | + let addresses = crate::rt::spawn_blocking(move || { |
| 225 | + let addr = (host.as_str(), port); |
| 226 | + ToSocketAddrs::to_socket_addrs(&addr) |
| 227 | + }) |
| 228 | + .await?; |
284 | 229 |
|
285 |
| - fn poll_shutdown(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> { |
286 |
| - unreachable!() |
| 230 | + let mut last_err = None; |
| 231 | + |
| 232 | + // Loop through all the Socket Addresses that the hostname resolves to |
| 233 | + for socket_addr in addresses { |
| 234 | + match Async::<TcpStream>::connect(socket_addr).await { |
| 235 | + Ok(stream) => return Ok(stream), |
| 236 | + Err(e) => last_err = Some(e), |
| 237 | + } |
287 | 238 | }
|
| 239 | + |
| 240 | + // If we reach this point, it means we failed to connect to any of the addresses. |
| 241 | + // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address. |
| 242 | + Err(last_err |
| 243 | + .unwrap_or_else(|| { |
| 244 | + io::Error::new( |
| 245 | + io::ErrorKind::AddrNotAvailable, |
| 246 | + "Hostname did not resolve to any addresses", |
| 247 | + ) |
| 248 | + }) |
| 249 | + .into()) |
288 | 250 | }
|
289 | 251 |
|
290 | 252 | /// Connect a Unix Domain Socket at the given path.
|
|
0 commit comments