Skip to content

Commit 0324a54

Browse files
abonanderJosiahParry
authored andcommitted
refactor: tweaks after launchbadge#3791 (launchbadge#4022)
* restore fallback to `async-io` for `connect_tcp()` when `runtime-tokio` feature is enabled * `smol` and `async-global-executor` both use `async-task`, so `JoinHandle` impls can be consolidated * no need for duplicate `yield_now()` impls * delete `impl Socket for ()`
1 parent 06f1edf commit 0324a54

File tree

11 files changed

+117
-225
lines changed

11 files changed

+117
-225
lines changed

.github/workflows/sqlx.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
runs-on: ubuntu-24.04
2323
strategy:
2424
matrix:
25+
# Note: because `async-std` is deprecated, we only check it in a single job to save CI time.
2526
runtime: [ async-std, async-global-executor, smol, tokio ]
2627
tls: [ native-tls, rustls, none ]
2728
timeout-minutes: 30

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-cli/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,11 @@ features = [
5555
[features]
5656
default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"]
5757

58+
# TLS options
5859
rustls = ["sqlx/tls-rustls"]
5960
native-tls = ["sqlx/tls-native-tls"]
6061

62+
# databases
6163
mysql = ["sqlx/mysql"]
6264
postgres = ["sqlx/postgres"]
6365
sqlite = ["sqlx/sqlite", "_sqlite"]

sqlx-core/Cargo.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ any = []
2020
json = ["serde", "serde_json"]
2121

2222
# for conditional compilation
23-
_rt-async-global-executor = ["async-global-executor", "_rt-async-io"]
23+
_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"]
2424
_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration
2525
_rt-async-std = ["async-std", "_rt-async-io"]
26-
_rt-smol = ["smol", "_rt-async-io"]
26+
_rt-async-task = ["async-task"]
27+
_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"]
2728
_rt-tokio = ["tokio", "tokio-stream"]
29+
2830
_tls-native-tls = ["native-tls"]
2931
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
3032
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
@@ -68,6 +70,8 @@ mac_address = { workspace = true, optional = true }
6870
uuid = { workspace = true, optional = true }
6971

7072
async-io = { version = "2.4.1", optional = true }
73+
async-task = { version = "4.7.1", optional = true }
74+
7175
# work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281
7276
async-fs = { version = "2.1", optional = true }
7377
base64 = { version = "0.22.0", default-features = false, features = ["std"] }

sqlx-core/src/net/socket/mod.rs

Lines changed: 51 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1+
use std::future::Future;
12
use std::io;
23
use std::path::Path;
34
use std::pin::Pin;
45
use std::task::{ready, Context, Poll};
5-
use std::{
6-
future::Future,
7-
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
8-
};
96

7+
pub use buffered::{BufferedSocket, WriteBuffer};
108
use bytes::BufMut;
119
use cfg_if::cfg_if;
1210

13-
pub use buffered::{BufferedSocket, WriteBuffer};
14-
15-
use crate::{io::ReadBuf, rt::spawn_blocking};
11+
use crate::io::ReadBuf;
1612

1713
mod buffered;
1814

@@ -146,10 +142,7 @@ where
146142
pub trait WithSocket {
147143
type Output;
148144

149-
fn with_socket<S: Socket>(
150-
self,
151-
socket: S,
152-
) -> impl std::future::Future<Output = Self::Output> + Send;
145+
fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;
153146
}
154147

155148
pub struct SocketIntoBox;
@@ -193,98 +186,67 @@ pub async fn connect_tcp<Ws: WithSocket>(
193186
port: u16,
194187
with_socket: Ws,
195188
) -> crate::Result<Ws::Output> {
196-
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
197-
let host = host.trim_matches(&['[', ']'][..]);
198-
199-
let addresses = if let Ok(addr) = host.parse::<Ipv4Addr>() {
200-
let addr = SocketAddrV4::new(addr, port);
201-
vec![SocketAddr::V4(addr)].into_iter()
202-
} else if let Ok(addr) = host.parse::<Ipv6Addr>() {
203-
let addr = SocketAddrV6::new(addr, port, 0, 0);
204-
vec![SocketAddr::V6(addr)].into_iter()
205-
} else {
206-
let host = host.to_string();
207-
spawn_blocking(move || {
208-
let addr = (host.as_str(), port);
209-
ToSocketAddrs::to_socket_addrs(&addr)
210-
})
211-
.await?
212-
};
213-
214-
let mut last_err = None;
215-
216-
// Loop through all the Socket Addresses that the hostname resolves to
217-
for socket_addr in addresses {
218-
match connect_tcp_address(socket_addr).await {
219-
Ok(stream) => return Ok(with_socket.with_socket(stream).await),
220-
Err(e) => last_err = Some(e),
221-
}
189+
#[cfg(feature = "_rt-tokio")]
190+
if crate::rt::rt_tokio::available() {
191+
return Ok(with_socket
192+
.with_socket(tokio::net::TcpStream::connect((host, port)).await?)
193+
.await);
222194
}
223195

224-
// If we reach this point, it means we failed to connect to any of the addresses.
225-
// Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
226-
Err(match last_err {
227-
Some(err) => err,
228-
None => io::Error::new(
229-
io::ErrorKind::AddrNotAvailable,
230-
"Hostname did not resolve to any addresses",
231-
)
232-
.into(),
233-
})
234-
}
235-
236-
async fn connect_tcp_address(socket_addr: SocketAddr) -> crate::Result<impl Socket> {
237196
cfg_if! {
238-
if #[cfg(feature = "_rt-tokio")] {
239-
if crate::rt::rt_tokio::available() {
240-
use tokio::net::TcpStream;
241-
242-
let stream = TcpStream::connect(socket_addr).await?;
243-
stream.set_nodelay(true)?;
244-
245-
Ok(stream)
246-
} else {
247-
crate::rt::missing_rt(socket_addr)
248-
}
249-
} else if #[cfg(feature = "_rt-async-io")] {
250-
use async_io::Async;
251-
use std::net::TcpStream;
252-
253-
let stream = Async::<TcpStream>::connect(socket_addr).await?;
254-
stream.get_ref().set_nodelay(true)?;
255-
256-
Ok(stream)
197+
if #[cfg(feature = "_rt-async-io")] {
198+
Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
257199
} else {
258-
crate::rt::missing_rt(socket_addr);
259-
#[allow(unreachable_code)]
260-
Ok(())
200+
crate::rt::missing_rt((host, port, with_socket))
261201
}
262202
}
263203
}
264204

265-
// Work around `impl Socket`` and 'unability to specify test build cargo feature'.
266-
// `connect_tcp_address` compilation would fail without this impl with
267-
// 'cannot infer return type' error.
268-
impl Socket for () {
269-
fn try_read(&mut self, _: &mut dyn ReadBuf) -> io::Result<usize> {
270-
unreachable!()
271-
}
205+
/// Open a TCP socket to `host` and `port`.
206+
///
207+
/// If `host` is a hostname, attempt to connect to each address it resolves to.
208+
///
209+
/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
210+
#[cfg(feature = "_rt-async-io")]
211+
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
212+
use async_io::Async;
213+
use std::net::{IpAddr, TcpStream, ToSocketAddrs};
272214

273-
fn try_write(&mut self, _: &[u8]) -> io::Result<usize> {
274-
unreachable!()
275-
}
215+
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
216+
let host = host.trim_matches(&['[', ']'][..]);
276217

277-
fn poll_read_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
278-
unreachable!()
218+
if let Ok(addr) = host.parse::<IpAddr>() {
219+
return Ok(Async::<TcpStream>::connect((addr, port)).await?);
279220
}
280221

281-
fn poll_write_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
282-
unreachable!()
283-
}
222+
let host = host.to_string();
223+
224+
let addresses = crate::rt::spawn_blocking(move || {
225+
let addr = (host.as_str(), port);
226+
ToSocketAddrs::to_socket_addrs(&addr)
227+
})
228+
.await?;
284229

285-
fn poll_shutdown(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
286-
unreachable!()
230+
let mut last_err = None;
231+
232+
// Loop through all the Socket Addresses that the hostname resolves to
233+
for socket_addr in addresses {
234+
match Async::<TcpStream>::connect(socket_addr).await {
235+
Ok(stream) => return Ok(stream),
236+
Err(e) => last_err = Some(e),
237+
}
287238
}
239+
240+
// If we reach this point, it means we failed to connect to any of the addresses.
241+
// Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
242+
Err(last_err
243+
.unwrap_or_else(|| {
244+
io::Error::new(
245+
io::ErrorKind::AddrNotAvailable,
246+
"Hostname did not resolve to any addresses",
247+
)
248+
})
249+
.into())
288250
}
289251

290252
/// Connect a Unix Domain Socket at the given path.

0 commit comments

Comments
 (0)