From 2b0965413f62d82bd87d5bd1af1de26e0845d096 Mon Sep 17 00:00:00 2001 From: mivik Date: Mon, 10 Mar 2025 17:10:59 +0800 Subject: [PATCH 1/9] feat: update implementation of UserPtr --- src/ptr.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/ptr.rs b/src/ptr.rs index f05066bd..fa04197c 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -1,7 +1,7 @@ use axerrno::{LinuxError, LinuxResult}; use axhal::paging::{MappingFlags, PageTable}; use axtask::{TaskExtRef, current}; -use memory_addr::{MemoryAddr, PAGE_SIZE_4K, PageIter4K, VirtAddr}; +use memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr, VirtAddrRange}; use core::{alloc::Layout, ffi::CStr, slice}; @@ -21,16 +21,14 @@ fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> return Err(LinuxError::EFAULT); } - // TODO: currently we're doing a very basic and inefficient check, due to - // the fact that AddrSpace does not expose necessary API. let task = current(); let aspace = task.task_ext().aspace.lock(); - let pt = aspace.page_table(); - let page_start = start.align_down_4k(); - let page_end = (start + layout.size()).align_up_4k(); - for page in PageIter4K::new(page_start, page_end).unwrap() { - check_page(pt, page, access_flags)?; + if !aspace.check_region_access( + VirtAddrRange::from_start_size(start, layout.size()), + access_flags, + ) { + return Err(LinuxError::EFAULT); } Ok(()) @@ -40,11 +38,8 @@ fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'stat // TODO: see check_region let task = current(); let aspace = task.task_ext().aspace.lock(); - let pt = aspace.page_table(); let mut page = start.align_down_4k(); - check_page(pt, page, access_flags)?; - page += PAGE_SIZE_4K; let start: *const u8 = start.as_ptr(); let mut len = 0; @@ -52,8 +47,13 @@ fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'stat loop { // SAFETY: Outer caller has provided a pointer to a valid C string. let ptr = unsafe { start.add(len) }; - if ptr == page.as_ptr() { - check_page(pt, page, access_flags)?; + if ptr >= page.as_ptr() { + if !aspace.check_region_access( + VirtAddrRange::from_start_size(page, PAGE_SIZE_4K), + access_flags, + ) { + return Err(LinuxError::EFAULT); + } page += PAGE_SIZE_4K; } From b6e5bae8190468951c4237a0c24596d7d253c5d2 Mon Sep 17 00:00:00 2001 From: mivik Date: Mon, 10 Mar 2025 17:25:18 +0800 Subject: [PATCH 2/9] style: remove unused code --- src/ptr.rs | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/ptr.rs b/src/ptr.rs index fa04197c..9c670d29 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -1,20 +1,10 @@ use axerrno::{LinuxError, LinuxResult}; -use axhal::paging::{MappingFlags, PageTable}; +use axhal::paging::MappingFlags; use axtask::{TaskExtRef, current}; use memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr, VirtAddrRange}; use core::{alloc::Layout, ffi::CStr, slice}; -fn check_page(pt: &PageTable, page: VirtAddr, access_flags: MappingFlags) -> LinuxResult<()> { - let Ok((_, flags, _)) = pt.query(page) else { - return Err(LinuxError::EFAULT); - }; - if !flags.contains(access_flags) { - return Err(LinuxError::EFAULT); - } - Ok(()) -} - fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> LinuxResult<()> { let align = layout.align(); if start.as_usize() & (align - 1) != 0 { From 389271d2f70379271ca5fbee5c822376e9cffaac Mon Sep 17 00:00:00 2001 From: mivik Date: Mon, 10 Mar 2025 21:05:54 +0800 Subject: [PATCH 3/9] feat(ptr): use check_page in check_cstr --- src/ptr.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/ptr.rs b/src/ptr.rs index 9c670d29..ca5a229a 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -1,10 +1,20 @@ use axerrno::{LinuxError, LinuxResult}; -use axhal::paging::MappingFlags; +use axhal::paging::{MappingFlags, PageTable}; use axtask::{TaskExtRef, current}; use memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr, VirtAddrRange}; use core::{alloc::Layout, ffi::CStr, slice}; +fn check_page(pt: &PageTable, page: VirtAddr, access_flags: MappingFlags) -> LinuxResult<()> { + let Ok((_, flags, _)) = pt.query(page) else { + return Err(LinuxError::EFAULT); + }; + if !flags.contains(access_flags) { + return Err(LinuxError::EFAULT); + } + Ok(()) +} + fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> LinuxResult<()> { let align = layout.align(); if start.as_usize() & (align - 1) != 0 { @@ -28,6 +38,7 @@ fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'stat // TODO: see check_region let task = current(); let aspace = task.task_ext().aspace.lock(); + let pt = aspace.page_table(); let mut page = start.align_down_4k(); @@ -38,12 +49,7 @@ fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'stat // SAFETY: Outer caller has provided a pointer to a valid C string. let ptr = unsafe { start.add(len) }; if ptr >= page.as_ptr() { - if !aspace.check_region_access( - VirtAddrRange::from_start_size(page, PAGE_SIZE_4K), - access_flags, - ) { - return Err(LinuxError::EFAULT); - } + check_page(pt, page, access_flags)?; page += PAGE_SIZE_4K; } From 29793082f294cf80578b783a4f876051ab10dbc4 Mon Sep 17 00:00:00 2001 From: mivik Date: Tue, 11 Mar 2025 22:14:19 +0800 Subject: [PATCH 4/9] feat(ptr): force allocate pages on get ptr --- Cargo.lock | 1 + Cargo.toml | 3 ++- src/ptr.rs | 74 +++++++++++++++++++++++++++++++++--------------------- 3 files changed, 48 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f0d197b..168a5df9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1419,6 +1419,7 @@ dependencies = [ "memory_addr", "num_enum", "numeric-enum-macro", + "percpu", "spin", "syscalls", "toml_edit", diff --git a/Cargo.toml b/Cargo.toml index 838ca301..a3ec7740 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ xmas-elf = "0.9" spin = "0.9" crate_interface = "0.1" bitflags = "2.6" +percpu = "0.2.0" kernel-elf-parser = "0.3" num_enum = { version = "0.7", default-features = false } @@ -44,4 +45,4 @@ page_table_entry = { git = "https://github.com/yfblock/page_table_multiarch.git" x86 = "0.52" [build-dependencies] -toml_edit = "0.22" \ No newline at end of file +toml_edit = "0.22" diff --git a/src/ptr.rs b/src/ptr.rs index ca5a229a..67e6540b 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -1,19 +1,11 @@ use axerrno::{LinuxError, LinuxResult}; -use axhal::paging::{MappingFlags, PageTable}; +use axhal::paging::MappingFlags; use axtask::{TaskExtRef, current}; -use memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr, VirtAddrRange}; +use memory_addr::{MemoryAddr, PAGE_SIZE_4K, PageIter4K, VirtAddr, VirtAddrRange}; use core::{alloc::Layout, ffi::CStr, slice}; -fn check_page(pt: &PageTable, page: VirtAddr, access_flags: MappingFlags) -> LinuxResult<()> { - let Ok((_, flags, _)) = pt.query(page) else { - return Err(LinuxError::EFAULT); - }; - if !flags.contains(access_flags) { - return Err(LinuxError::EFAULT); - } - Ok(()) -} +use crate::mm::access_user_memory; fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> LinuxResult<()> { let align = layout.align(); @@ -31,34 +23,58 @@ fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> return Err(LinuxError::EFAULT); } + // Now force each page to be loaded into memory. + access_user_memory(|| { + let page_start = start.align_down_4k(); + let page_end = (start + layout.size()).align_up_4k(); + for page in PageIter4K::new(page_start, page_end).unwrap() { + // SAFETY: The page is valid and we've checked the access flags. + unsafe { page.as_ptr_of::().read_volatile() }; + } + }); + Ok(()) } fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'static CStr> { - // TODO: see check_region - let task = current(); - let aspace = task.task_ext().aspace.lock(); - let pt = aspace.page_table(); - let mut page = start.align_down_4k(); let start: *const u8 = start.as_ptr(); let mut len = 0; - loop { - // SAFETY: Outer caller has provided a pointer to a valid C string. - let ptr = unsafe { start.add(len) }; - if ptr >= page.as_ptr() { - check_page(pt, page, access_flags)?; - page += PAGE_SIZE_4K; + access_user_memory(|| { + loop { + // SAFETY: Outer caller has provided a pointer to a valid C string. + let ptr = unsafe { start.add(len) }; + if ptr >= page.as_ptr() { + // We cannot prepare `aspace` outside of the loop, since holding + // aspace requires a mutex which would be required on page + // fault, and page faults can trigger inside the loop. + + // TODO: this is inefficient, but we have to do this instead of + // querying the page table since the page might has not been + // allocated yet. + let task = current(); + let aspace = task.task_ext().aspace.lock(); + if !aspace.check_region_access( + VirtAddrRange::from_start_size(page, PAGE_SIZE_4K), + access_flags, + ) { + return Err(LinuxError::EFAULT); + } + + page += PAGE_SIZE_4K; + } + + // This might trigger a page fault + // SAFETY: The pointer is valid and points to a valid memory region. + if unsafe { *ptr } == 0 { + break; + } + len += 1; } - - // SAFETY: The pointer is valid and points to a valid memory region. - if unsafe { *ptr } == 0 { - break; - } - len += 1; - } + Ok(()) + })?; // SAFETY: We've checked that the memory region contains a valid C string. Ok(unsafe { CStr::from_bytes_with_nul_unchecked(slice::from_raw_parts(start, len + 1)) }) From a7639c57e38df2afc69bd22d5e14614b994c00ae Mon Sep 17 00:00:00 2001 From: mivik Date: Tue, 11 Mar 2025 22:23:37 +0800 Subject: [PATCH 5/9] fix: upload mm.rs --- src/mm.rs | 46 +++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/src/mm.rs b/src/mm.rs index 6ba3b76b..fd30ea48 100644 --- a/src/mm.rs +++ b/src/mm.rs @@ -171,28 +171,40 @@ pub fn load_user_app( Ok((entry, user_sp)) } +#[percpu::def_percpu] +static mut ACCESSING_USER_MEM: bool = false; + +/// Enables scoped access into user memory, allowing page faults to occur inside +/// kernel. +pub fn access_user_memory(f: impl FnOnce() -> R) -> R { + ACCESSING_USER_MEM.write_current(true); + let result = f(); + ACCESSING_USER_MEM.write_current(false); + result +} + #[register_trap_handler(PAGE_FAULT)] fn handle_page_fault(vaddr: VirtAddr, access_flags: MappingFlags, is_user: bool) -> bool { warn!( "Page fault at {:#x}, access_flags: {:#x?}", vaddr, access_flags ); - if is_user { - if !axtask::current() - .task_ext() - .aspace - .lock() - .handle_page_fault(vaddr, access_flags) - { - warn!( - "{}: segmentation fault at {:#x}, exit!", - axtask::current().id_name(), - vaddr - ); - axtask::exit(-1); - } - true - } else { - false + if !is_user && !ACCESSING_USER_MEM.read_current() { + return false; + } + + if !axtask::current() + .task_ext() + .aspace + .lock() + .handle_page_fault(vaddr, access_flags) + { + warn!( + "{}: segmentation fault at {:#x}, exit!", + axtask::current().id_name(), + vaddr + ); + axtask::exit(-1); } + true } From 1055800a245ac2146f28f679e28a8df5f7f27e5c Mon Sep 17 00:00:00 2001 From: mivik Date: Tue, 11 Mar 2025 22:57:34 +0800 Subject: [PATCH 6/9] fix: disable preemption inside access_user_memory --- src/mm.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mm.rs b/src/mm.rs index fd30ea48..df31ce38 100644 --- a/src/mm.rs +++ b/src/mm.rs @@ -59,7 +59,7 @@ fn map_elf( &interp_elf, axconfig::plat::USER_INTERP_BASE, Some(uspace_base as isize), - uspace_base, + // uspace_base, ) .map_err(|_| AxError::InvalidData)?; // Set the first argument to the path of the user app. @@ -177,10 +177,12 @@ static mut ACCESSING_USER_MEM: bool = false; /// Enables scoped access into user memory, allowing page faults to occur inside /// kernel. pub fn access_user_memory(f: impl FnOnce() -> R) -> R { - ACCESSING_USER_MEM.write_current(true); - let result = f(); - ACCESSING_USER_MEM.write_current(false); - result + ACCESSING_USER_MEM.with_current(|v| { + *v = true; + let result = f(); + *v = false; + result + }) } #[register_trap_handler(PAGE_FAULT)] From f2d084861a2f65f00e124bd8b9be592063e01b7d Mon Sep 17 00:00:00 2001 From: mivik Date: Wed, 12 Mar 2025 10:39:50 +0800 Subject: [PATCH 7/9] feat(ptr): add get_as_null_terminated to replace get_as_cstr --- Cargo.lock | 1 + Cargo.toml | 1 + src/mm.rs | 2 +- src/ptr.rs | 68 ++++++++++++++++++++++++++++++-------- src/syscall_imp/fs/ctl.rs | 8 ++--- src/syscall_imp/fs/io.rs | 2 +- src/syscall_imp/fs/stat.rs | 2 +- src/syscall_imp/mod.rs | 4 +-- 8 files changed, 66 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e6c31a09..3821fcc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1421,6 +1421,7 @@ dependencies = [ "numeric-enum-macro", "percpu", "spin", + "static_assertions", "syscalls", "toml_edit", "x86", diff --git a/Cargo.toml b/Cargo.toml index a3ec7740..0a77bf51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ kernel-elf-parser = "0.3" num_enum = { version = "0.7", default-features = false } syscalls = { version = "0.6", default-features = false } numeric-enum-macro = "0.2.0" +static_assertions = "1.1.0" axconfig = { git = "https://github.com/oscomp/arceos.git" } axfs = { git = "https://github.com/oscomp/arceos.git" } diff --git a/src/mm.rs b/src/mm.rs index df31ce38..2086966e 100644 --- a/src/mm.rs +++ b/src/mm.rs @@ -59,7 +59,7 @@ fn map_elf( &interp_elf, axconfig::plat::USER_INTERP_BASE, Some(uspace_base as isize), - // uspace_base, + uspace_base, ) .map_err(|_| AxError::InvalidData)?; // Set the first argument to the path of the user app. diff --git a/src/ptr.rs b/src/ptr.rs index 67e6540b..79328dbb 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -3,7 +3,7 @@ use axhal::paging::MappingFlags; use axtask::{TaskExtRef, current}; use memory_addr::{MemoryAddr, PAGE_SIZE_4K, PageIter4K, VirtAddr, VirtAddrRange}; -use core::{alloc::Layout, ffi::CStr, slice}; +use core::{alloc::Layout, ffi::c_char, mem, slice, str}; use crate::mm::access_user_memory; @@ -36,17 +36,28 @@ fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> Ok(()) } -fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'static CStr> { +fn check_null_terminated( + start: VirtAddr, + access_flags: MappingFlags, +) -> LinuxResult<&'static [T]> { + let align = Layout::new::().align(); + if start.as_usize() & (align - 1) != 0 { + return Err(LinuxError::EFAULT); + } + + let zero = T::default(); + let mut page = start.align_down_4k(); - let start: *const u8 = start.as_ptr(); + let start = start.as_ptr_of::(); let mut len = 0; access_user_memory(|| { loop { - // SAFETY: Outer caller has provided a pointer to a valid C string. + // SAFETY: This won't overflow the address space since we'll check + // it below. let ptr = unsafe { start.add(len) }; - if ptr >= page.as_ptr() { + while ptr as usize >= page.as_ptr() as usize { // We cannot prepare `aspace` outside of the loop, since holding // aspace requires a mutex which would be required on page // fault, and page faults can trigger inside the loop. @@ -68,7 +79,7 @@ fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'stat // This might trigger a page fault // SAFETY: The pointer is valid and points to a valid memory region. - if unsafe { *ptr } == 0 { + if unsafe { ptr.read_volatile() } == zero { break; } len += 1; @@ -77,7 +88,7 @@ fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'stat })?; // SAFETY: We've checked that the memory region contains a valid C string. - Ok(unsafe { CStr::from_bytes_with_nul_unchecked(slice::from_raw_parts(start, len + 1)) }) + Ok(unsafe { slice::from_raw_parts(start, len) }) } /// A trait representing a pointer in user space, which can be converted to a @@ -133,12 +144,6 @@ pub trait PtrWrapper: Sized { )?; unsafe { Ok(self.into_inner()) } } - - /// Get the pointer as `&CStr`, validating the memory region specified by - /// the size of a C string. - fn get_as_cstr(self) -> LinuxResult<&'static CStr> { - check_cstr(self.address(), Self::ACCESS_FLAGS) - } } /// A pointer to user space memory. @@ -167,6 +172,19 @@ impl PtrWrapper for UserPtr { } } +impl UserPtr { + /// Get the pointer as `&mut [T]`, terminated by a null value, validating + /// the memory region. + pub fn get_as_null_terminated(self) -> LinuxResult<&'static mut [T]> + where + T: Eq + Default, + { + let slice = check_null_terminated::(self.address(), Self::ACCESS_FLAGS)?; + // SAFETY: The pointer is mutable and we've validated the memory region. + unsafe { Ok(slice::from_raw_parts_mut(slice.as_ptr() as _, slice.len())) } + } +} + /// An immutable pointer to user space memory. /// /// See [`PtrWrapper`] for more details. @@ -192,3 +210,27 @@ impl PtrWrapper for UserConstPtr { VirtAddr::from_ptr_of(self.0) } } + +impl UserConstPtr { + /// Get the pointer as `&[T]`, terminated by a null value, validating the + /// memory region. + pub fn get_as_null_terminated(self) -> LinuxResult<&'static [T]> + where + T: Eq + Default, + { + check_null_terminated::(self.address(), Self::ACCESS_FLAGS) + } +} + +static_assertions::const_assert_eq!(size_of::(), size_of::()); + +impl UserConstPtr { + /// Get the pointer as `&str`, validating the memory region. + pub fn get_as_str(self) -> LinuxResult> { + let slice = self.get_as_null_terminated()?; + // SAFETY: c_char is u8 + let slice = unsafe { mem::transmute::<&[c_char], &[u8]>(slice) }; + + Ok(str::from_utf8(slice).ok()) + } +} diff --git a/src/syscall_imp/fs/ctl.rs b/src/syscall_imp/fs/ctl.rs index 75391ac7..cfc8f931 100644 --- a/src/syscall_imp/fs/ctl.rs +++ b/src/syscall_imp/fs/ctl.rs @@ -245,8 +245,8 @@ pub(crate) fn sys_linkat( new_path: UserConstPtr, flags: i32, ) -> i32 { - let old_path = syscall_unwrap!(old_path.get_as_cstr()); - let new_path = syscall_unwrap!(new_path.get_as_cstr()); + let old_path = syscall_unwrap!(old_path.get_as_null_terminated()); + let new_path = syscall_unwrap!(new_path.get_as_null_terminated()); if flags != 0 { warn!("Unsupported flags: {flags}"); @@ -281,7 +281,7 @@ pub(crate) fn sys_linkat( /// flags: can be 0 or AT_REMOVEDIR /// return 0 when success, else return -1 pub fn sys_unlinkat(dir_fd: isize, path: UserConstPtr, flags: usize) -> isize { - let path = syscall_unwrap!(path.get_as_cstr()); + let path = syscall_unwrap!(path.get_as_null_terminated()); const AT_REMOVEDIR: usize = 0x200; @@ -316,7 +316,7 @@ pub(crate) fn sys_getcwd(buf: UserPtr, size: usize) -> *mut c_char { syscall_body!( sys_getcwd, Ok(arceos_posix_api::sys_getcwd( - buf.get_as_cstr()?.as_ptr() as _, + buf.get_as_null_terminated()?.as_ptr() as _, size )) ) diff --git a/src/syscall_imp/fs/io.rs b/src/syscall_imp/fs/io.rs index 58be3956..4876b1cf 100644 --- a/src/syscall_imp/fs/io.rs +++ b/src/syscall_imp/fs/io.rs @@ -28,7 +28,7 @@ pub(crate) fn sys_openat( flags: i32, modes: mode_t, ) -> isize { - let path = syscall_unwrap!(path.get_as_cstr()); + let path = syscall_unwrap!(path.get_as_null_terminated()); api::sys_openat(dirfd, path.as_ptr(), flags, modes) as _ } diff --git a/src/syscall_imp/fs/stat.rs b/src/syscall_imp/fs/stat.rs index 61490c6a..ae09e9f4 100644 --- a/src/syscall_imp/fs/stat.rs +++ b/src/syscall_imp/fs/stat.rs @@ -100,7 +100,7 @@ pub fn sys_fstatat( _flags: i32, ) -> i32 { syscall_body!(sys_fstatat, { - let path = path.get_as_cstr()?; + let path = path.get_as_null_terminated()?; let path = arceos_posix_api::handle_file_path(dir_fd, Some(path.as_ptr() as _), false)?; let kstatbuf = kstatbuf.get()?; diff --git a/src/syscall_imp/mod.rs b/src/syscall_imp/mod.rs index 4b165e81..25334851 100644 --- a/src/syscall_imp/mod.rs +++ b/src/syscall_imp/mod.rs @@ -8,7 +8,7 @@ mod utils; use core::ffi::c_char; use crate::{ - ptr::{PtrWrapper, UserConstPtr}, + ptr::UserConstPtr, task::{time_stat_from_kernel_to_user, time_stat_from_user_to_kernel}, }; use axerrno::LinuxError; @@ -62,7 +62,7 @@ macro_rules! syscall_body { } pub(crate) fn read_path_str(path: UserConstPtr) -> Result<&'static str, LinuxError> { - path.get_as_cstr()?.to_str().map_err(|_| { + path.get_as_str()?.ok_or_else(|| { warn!("Invalid path"); LinuxError::EFAULT }) From f04ede61896761987a7c4dec8e5cc3c70da5ad3a Mon Sep 17 00:00:00 2001 From: mivik Date: Wed, 12 Mar 2025 10:43:28 +0800 Subject: [PATCH 8/9] feat(ptr): add nullable helper --- src/ptr.rs | 8 ++++++++ src/syscall_imp/task/thread.rs | 9 ++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/ptr.rs b/src/ptr.rs index 79328dbb..3a6d99fc 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -144,6 +144,14 @@ pub trait PtrWrapper: Sized { )?; unsafe { Ok(self.into_inner()) } } + + fn nullable(self, f: impl FnOnce(Self) -> LinuxResult) -> LinuxResult> { + if self.address().as_ptr().is_null() { + Ok(None) + } else { + f(self).map(Some) + } + } } /// A pointer to user space memory. diff --git a/src/syscall_imp/task/thread.rs b/src/syscall_imp/task/thread.rs index 78a7ff9d..65d69c0a 100644 --- a/src/syscall_imp/task/thread.rs +++ b/src/syscall_imp/task/thread.rs @@ -1,4 +1,7 @@ -use core::ffi::{c_char, c_int}; +use core::{ + ffi::{c_char, c_int}, + ptr, +}; use axerrno::LinuxError; use axtask::{TaskExtRef, current, yield_now}; @@ -144,9 +147,9 @@ pub(crate) fn sys_clone( pub(crate) fn sys_wait4(pid: i32, exit_code_ptr: UserPtr, option: u32) -> isize { let option_flag = WaitFlags::from_bits(option).unwrap(); syscall_body!(sys_wait4, { - let exit_code_ptr = exit_code_ptr.get()?; + let exit_code_ptr = exit_code_ptr.nullable(UserPtr::get)?; loop { - let answer = wait_pid(pid, exit_code_ptr); + let answer = wait_pid(pid, exit_code_ptr.unwrap_or_else(ptr::null_mut)); match answer { Ok(pid) => { return Ok(pid as isize); From a810abdffe566b18dd5999fdcbf33dbc5c3eeaca Mon Sep 17 00:00:00 2001 From: mivik Date: Wed, 12 Mar 2025 12:32:37 +0800 Subject: [PATCH 9/9] style(ptr): refactor implementation --- src/ptr.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/ptr.rs b/src/ptr.rs index 3a6d99fc..c1695220 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -39,7 +39,7 @@ fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> fn check_null_terminated( start: VirtAddr, access_flags: MappingFlags, -) -> LinuxResult<&'static [T]> { +) -> LinuxResult<(*const T, usize)> { let align = Layout::new::().align(); if start.as_usize() & (align - 1) != 0 { return Err(LinuxError::EFAULT); @@ -87,8 +87,7 @@ fn check_null_terminated( Ok(()) })?; - // SAFETY: We've checked that the memory region contains a valid C string. - Ok(unsafe { slice::from_raw_parts(start, len) }) + Ok((start, len)) } /// A trait representing a pointer in user space, which can be converted to a @@ -187,9 +186,9 @@ impl UserPtr { where T: Eq + Default, { - let slice = check_null_terminated::(self.address(), Self::ACCESS_FLAGS)?; - // SAFETY: The pointer is mutable and we've validated the memory region. - unsafe { Ok(slice::from_raw_parts_mut(slice.as_ptr() as _, slice.len())) } + let (ptr, len) = check_null_terminated::(self.address(), Self::ACCESS_FLAGS)?; + // SAFETY: We've validated the memory region. + unsafe { Ok(slice::from_raw_parts_mut(ptr as *mut _, len)) } } } @@ -226,7 +225,9 @@ impl UserConstPtr { where T: Eq + Default, { - check_null_terminated::(self.address(), Self::ACCESS_FLAGS) + let (ptr, len) = check_null_terminated::(self.address(), Self::ACCESS_FLAGS)?; + // SAFETY: We've validated the memory region. + unsafe { Ok(slice::from_raw_parts(ptr, len)) } } }