Skip to content

Commit 14a9e7f

Browse files
committed
feature(virtio-net): VIRTIO_NET_F_GUEST_CSUM support
1 parent 6580d1a commit 14a9e7f

File tree

1 file changed

+121
-15
lines changed

1 file changed

+121
-15
lines changed

src/drivers/net/virtio/mod.rs

Lines changed: 121 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ use core::mem::{ManuallyDrop, MaybeUninit, transmute};
1616

1717
use smallvec::SmallVec;
1818
use smoltcp::phy::{Checksum, ChecksumCapabilities, DeviceCapabilities};
19-
use smoltcp::wire::{ETHERNET_HEADER_LEN, EthernetFrame, Ipv4Packet, Ipv6Packet};
19+
use smoltcp::wire::{
20+
ETHERNET_HEADER_LEN, EthernetFrame, IpAddress, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket,
21+
UdpPacket,
22+
};
2023
use virtio::net::{ConfigVolatileFieldAccess, Hdr, HdrF};
2124
use virtio::{DeviceConfigSpace, FeatureBits};
2225
use volatile::VolatileRef;
@@ -295,6 +298,101 @@ impl smoltcp::phy::TxToken for TxToken<'_> {
295298
pub struct RxToken<'a> {
296299
recv_vqs: &'a mut RxQueues,
297300
is_mrg_rxbuf_enabled: bool,
301+
checksums: ChecksumCapabilities,
302+
}
303+
304+
impl RxToken<'_> {
305+
/// If we advertised receive checksum offload to smoltcp, we need to validate the packet
306+
/// either by checking its virtio-net headers or checksum. Otherwise, it's smoltcp's responsibility
307+
/// to validate the frame and we can pass the frame directly.
308+
fn is_ethernet_frame_passable(&self, hdr: &Hdr, frame: &[u8]) -> bool {
309+
// Nothing is offloaded to the device. We can pass the frame right off to smoltcp.
310+
if self.checksums.tcp.rx() && self.checksums.udp.rx() {
311+
return true;
312+
}
313+
314+
let Ok(ethernet_frame) = EthernetFrame::new_checked(frame) else {
315+
return false;
316+
};
317+
318+
// We are receiving a frame that was sent by another virtio-net driver on the same host.
319+
// Normally, the device should have filled in the checksum but passed the buffers right along
320+
// instead as checksumming is not necessary for two guests on the same host.
321+
if hdr.flags.contains(virtio::net::HdrF::NEEDS_CSUM) {
322+
return true;
323+
}
324+
325+
// We cannot benefit from the same host optimization but we've promised smoltcp to only pass frames
326+
// that are validated so we need to do the validation ourselves.
327+
match ethernet_frame.ethertype() {
328+
smoltcp::wire::EthernetProtocol::Ipv4 => {
329+
let Ok(ip_packet) = Ipv4Packet::new_checked(ethernet_frame.payload()) else {
330+
return false;
331+
};
332+
333+
// DATA_VALID only validates the outermost packet checksum, which is IPv4 in this case. Thus,
334+
// it does not save us from validating the layer above IP.
335+
Self::is_ip_packet_passable(
336+
ip_packet.next_header(),
337+
ip_packet.payload(),
338+
IpAddress::Ipv4(ip_packet.src_addr()),
339+
IpAddress::Ipv4(ip_packet.dst_addr()),
340+
&self.checksums,
341+
)
342+
}
343+
smoltcp::wire::EthernetProtocol::Ipv6 => {
344+
let Ok(ip_packet) = Ipv6Packet::new_checked(ethernet_frame.payload()) else {
345+
return false;
346+
};
347+
// One level of checksum has been validated and IPv6 headers don't have their own checksums,
348+
// so the validation from the device must have been for the IP protocol.
349+
hdr.flags.contains(virtio::net::HdrF::DATA_VALID) || Self::is_ip_packet_passable(
350+
ip_packet.next_header(),
351+
ip_packet.payload(),
352+
IpAddress::Ipv6(ip_packet.src_addr()),
353+
IpAddress::Ipv6(ip_packet.dst_addr()),
354+
&self.checksums,
355+
)
356+
}
357+
// ARP packets don't have checksums.
358+
smoltcp::wire::EthernetProtocol::Arp
359+
// We should have not taken over the validation of any unknown protocol from smoltcp and may let
360+
// it take care of it.
361+
| smoltcp::wire::EthernetProtocol::Unknown(_) => {
362+
true
363+
}
364+
}
365+
}
366+
367+
fn is_ip_packet_passable(
368+
next_header: IpProtocol,
369+
payload: &[u8],
370+
src_addr: IpAddress,
371+
dst_addr: IpAddress,
372+
checksum_capabilities: &ChecksumCapabilities,
373+
) -> bool {
374+
match next_header {
375+
smoltcp::wire::IpProtocol::Tcp => {
376+
if checksum_capabilities.tcp.rx() {
377+
return true;
378+
}
379+
let Ok(packet) = TcpPacket::new_checked(payload) else {
380+
return false;
381+
};
382+
packet.verify_checksum(&src_addr, &dst_addr)
383+
}
384+
smoltcp::wire::IpProtocol::Udp => {
385+
if checksum_capabilities.udp.rx() {
386+
return true;
387+
}
388+
let Ok(packet) = UdpPacket::new_checked(payload) else {
389+
return false;
390+
};
391+
packet.verify_checksum(&src_addr, &dst_addr)
392+
}
393+
_ => true,
394+
}
395+
}
298396
}
299397

300398
impl smoltcp::phy::RxToken for RxToken<'_> {
@@ -325,18 +423,6 @@ impl smoltcp::phy::RxToken for RxToken<'_> {
325423
};
326424

327425
let mut combined_packets = first_packet;
328-
329-
let first_tkn = buffer_token_from_hdr(
330-
// SAFETY: Box<T> -> Box<MaybeUninit<T>> is sound
331-
unsafe {
332-
transmute::<Box<Hdr, DeviceAlloc>, Box<MaybeUninit<Hdr>, DeviceAlloc>>(first_header)
333-
},
334-
self.recv_vqs.buf_size,
335-
);
336-
self.recv_vqs.vqs[0]
337-
.dispatch(first_tkn, false, BufferType::Direct)
338-
.unwrap();
339-
340426
for _ in 1..num_buffers {
341427
let mut buffer_tkn = self.recv_vqs.get_next().unwrap();
342428
// The descriptor that was meant for the header of another frame was used for a portion of the current frame's contents.
@@ -361,7 +447,24 @@ impl smoltcp::phy::RxToken for RxToken<'_> {
361447
.unwrap();
362448
}
363449

364-
f(&combined_packets)
450+
let res = if self.is_ethernet_frame_passable(&first_header, &combined_packets) {
451+
f(&combined_packets)
452+
} else {
453+
f(&[])
454+
};
455+
456+
let first_tkn = buffer_token_from_hdr(
457+
// SAFETY: Box<T> -> Box<MaybeUninit<T>> is sound
458+
unsafe {
459+
transmute::<Box<Hdr, DeviceAlloc>, Box<MaybeUninit<Hdr>, DeviceAlloc>>(first_header)
460+
},
461+
self.recv_vqs.buf_size,
462+
);
463+
self.recv_vqs.vqs[0]
464+
.dispatch(first_tkn, false, BufferType::Direct)
465+
.unwrap();
466+
467+
res
365468
}
366469
}
367470

@@ -436,6 +539,7 @@ impl smoltcp::phy::Device for VirtioNetDriver {
436539
RxToken {
437540
recv_vqs: &mut self.inner.recv_vqs,
438541
is_mrg_rxbuf_enabled: self.dev_cfg.features.contains(virtio::net::F::MRG_RXBUF),
542+
checksums: self.checksums.clone(),
439543
},
440544
TxToken {
441545
send_vqs: &mut self.inner.send_vqs,
@@ -647,7 +751,9 @@ impl VirtioNetDriver<Uninit> {
647751
// Multiqueue support
648752
| virtio::net::F::MQ
649753
// Checksum calculation can partially be offloaded to the device
650-
| virtio::net::F::CSUM;
754+
| virtio::net::F::CSUM
755+
// Partially checksummed frames can be received
756+
| virtio::net::F::GUEST_CSUM;
651757

652758
// Currently the driver does NOT support the features below.
653759
// In order to provide functionality for these, the driver

0 commit comments

Comments
 (0)