diff --git a/Cargo.lock b/Cargo.lock index 092427e2..3821fcc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1419,7 +1419,9 @@ dependencies = [ "memory_addr", "num_enum", "numeric-enum-macro", + "percpu", "spin", + "static_assertions", "syscalls", "toml_edit", "x86", diff --git a/Cargo.toml b/Cargo.toml index 838ca301..0a77bf51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,11 +18,13 @@ xmas-elf = "0.9" spin = "0.9" crate_interface = "0.1" bitflags = "2.6" +percpu = "0.2.0" kernel-elf-parser = "0.3" num_enum = { version = "0.7", default-features = false } syscalls = { version = "0.6", default-features = false } numeric-enum-macro = "0.2.0" +static_assertions = "1.1.0" axconfig = { git = "https://github.com/oscomp/arceos.git" } axfs = { git = "https://github.com/oscomp/arceos.git" } @@ -44,4 +46,4 @@ page_table_entry = { git = "https://github.com/yfblock/page_table_multiarch.git" x86 = "0.52" [build-dependencies] -toml_edit = "0.22" \ No newline at end of file +toml_edit = "0.22" diff --git a/src/mm.rs b/src/mm.rs index 6ba3b76b..2086966e 100644 --- a/src/mm.rs +++ b/src/mm.rs @@ -171,28 +171,42 @@ pub fn load_user_app( Ok((entry, user_sp)) } +#[percpu::def_percpu] +static mut ACCESSING_USER_MEM: bool = false; + +/// Enables scoped access into user memory, allowing page faults to occur inside +/// kernel. +pub fn access_user_memory(f: impl FnOnce() -> R) -> R { + ACCESSING_USER_MEM.with_current(|v| { + *v = true; + let result = f(); + *v = false; + result + }) +} + #[register_trap_handler(PAGE_FAULT)] fn handle_page_fault(vaddr: VirtAddr, access_flags: MappingFlags, is_user: bool) -> bool { warn!( "Page fault at {:#x}, access_flags: {:#x?}", vaddr, access_flags ); - if is_user { - if !axtask::current() - .task_ext() - .aspace - .lock() - .handle_page_fault(vaddr, access_flags) - { - warn!( - "{}: segmentation fault at {:#x}, exit!", - axtask::current().id_name(), - vaddr - ); - axtask::exit(-1); - } - true - } else { - false + if !is_user && !ACCESSING_USER_MEM.read_current() { + return false; + } + + if !axtask::current() + .task_ext() + .aspace + .lock() + .handle_page_fault(vaddr, access_flags) + { + warn!( + "{}: segmentation fault at {:#x}, exit!", + axtask::current().id_name(), + vaddr + ); + axtask::exit(-1); } + true } diff --git a/src/ptr.rs b/src/ptr.rs index f05066bd..c1695220 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -1,19 +1,11 @@ use axerrno::{LinuxError, LinuxResult}; -use axhal::paging::{MappingFlags, PageTable}; +use axhal::paging::MappingFlags; use axtask::{TaskExtRef, current}; -use memory_addr::{MemoryAddr, PAGE_SIZE_4K, PageIter4K, VirtAddr}; +use memory_addr::{MemoryAddr, PAGE_SIZE_4K, PageIter4K, VirtAddr, VirtAddrRange}; -use core::{alloc::Layout, ffi::CStr, slice}; +use core::{alloc::Layout, ffi::c_char, mem, slice, str}; -fn check_page(pt: &PageTable, page: VirtAddr, access_flags: MappingFlags) -> LinuxResult<()> { - let Ok((_, flags, _)) = pt.query(page) else { - return Err(LinuxError::EFAULT); - }; - if !flags.contains(access_flags) { - return Err(LinuxError::EFAULT); - } - Ok(()) -} +use crate::mm::access_user_memory; fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> LinuxResult<()> { let align = layout.align(); @@ -21,51 +13,81 @@ fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> return Err(LinuxError::EFAULT); } - // TODO: currently we're doing a very basic and inefficient check, due to - // the fact that AddrSpace does not expose necessary API. let task = current(); let aspace = task.task_ext().aspace.lock(); - let pt = aspace.page_table(); - let page_start = start.align_down_4k(); - let page_end = (start + layout.size()).align_up_4k(); - for page in PageIter4K::new(page_start, page_end).unwrap() { - check_page(pt, page, access_flags)?; + if !aspace.check_region_access( + VirtAddrRange::from_start_size(start, layout.size()), + access_flags, + ) { + return Err(LinuxError::EFAULT); } + // Now force each page to be loaded into memory. + access_user_memory(|| { + let page_start = start.align_down_4k(); + let page_end = (start + layout.size()).align_up_4k(); + for page in PageIter4K::new(page_start, page_end).unwrap() { + // SAFETY: The page is valid and we've checked the access flags. + unsafe { page.as_ptr_of::().read_volatile() }; + } + }); + Ok(()) } -fn check_cstr(start: VirtAddr, access_flags: MappingFlags) -> LinuxResult<&'static CStr> { - // TODO: see check_region - let task = current(); - let aspace = task.task_ext().aspace.lock(); - let pt = aspace.page_table(); +fn check_null_terminated( + start: VirtAddr, + access_flags: MappingFlags, +) -> LinuxResult<(*const T, usize)> { + let align = Layout::new::().align(); + if start.as_usize() & (align - 1) != 0 { + return Err(LinuxError::EFAULT); + } + + let zero = T::default(); let mut page = start.align_down_4k(); - check_page(pt, page, access_flags)?; - page += PAGE_SIZE_4K; - let start: *const u8 = start.as_ptr(); + let start = start.as_ptr_of::(); let mut len = 0; - loop { - // SAFETY: Outer caller has provided a pointer to a valid C string. - let ptr = unsafe { start.add(len) }; - if ptr == page.as_ptr() { - check_page(pt, page, access_flags)?; - page += PAGE_SIZE_4K; + access_user_memory(|| { + loop { + // SAFETY: This won't overflow the address space since we'll check + // it below. + let ptr = unsafe { start.add(len) }; + while ptr as usize >= page.as_ptr() as usize { + // We cannot prepare `aspace` outside of the loop, since holding + // aspace requires a mutex which would be required on page + // fault, and page faults can trigger inside the loop. + + // TODO: this is inefficient, but we have to do this instead of + // querying the page table since the page might has not been + // allocated yet. + let task = current(); + let aspace = task.task_ext().aspace.lock(); + if !aspace.check_region_access( + VirtAddrRange::from_start_size(page, PAGE_SIZE_4K), + access_flags, + ) { + return Err(LinuxError::EFAULT); + } + + page += PAGE_SIZE_4K; + } + + // This might trigger a page fault + // SAFETY: The pointer is valid and points to a valid memory region. + if unsafe { ptr.read_volatile() } == zero { + break; + } + len += 1; } + Ok(()) + })?; - // SAFETY: The pointer is valid and points to a valid memory region. - if unsafe { *ptr } == 0 { - break; - } - len += 1; - } - - // SAFETY: We've checked that the memory region contains a valid C string. - Ok(unsafe { CStr::from_bytes_with_nul_unchecked(slice::from_raw_parts(start, len + 1)) }) + Ok((start, len)) } /// A trait representing a pointer in user space, which can be converted to a @@ -122,10 +144,12 @@ pub trait PtrWrapper: Sized { unsafe { Ok(self.into_inner()) } } - /// Get the pointer as `&CStr`, validating the memory region specified by - /// the size of a C string. - fn get_as_cstr(self) -> LinuxResult<&'static CStr> { - check_cstr(self.address(), Self::ACCESS_FLAGS) + fn nullable(self, f: impl FnOnce(Self) -> LinuxResult) -> LinuxResult> { + if self.address().as_ptr().is_null() { + Ok(None) + } else { + f(self).map(Some) + } } } @@ -155,6 +179,19 @@ impl PtrWrapper for UserPtr { } } +impl UserPtr { + /// Get the pointer as `&mut [T]`, terminated by a null value, validating + /// the memory region. + pub fn get_as_null_terminated(self) -> LinuxResult<&'static mut [T]> + where + T: Eq + Default, + { + let (ptr, len) = check_null_terminated::(self.address(), Self::ACCESS_FLAGS)?; + // SAFETY: We've validated the memory region. + unsafe { Ok(slice::from_raw_parts_mut(ptr as *mut _, len)) } + } +} + /// An immutable pointer to user space memory. /// /// See [`PtrWrapper`] for more details. @@ -180,3 +217,29 @@ impl PtrWrapper for UserConstPtr { VirtAddr::from_ptr_of(self.0) } } + +impl UserConstPtr { + /// Get the pointer as `&[T]`, terminated by a null value, validating the + /// memory region. + pub fn get_as_null_terminated(self) -> LinuxResult<&'static [T]> + where + T: Eq + Default, + { + let (ptr, len) = check_null_terminated::(self.address(), Self::ACCESS_FLAGS)?; + // SAFETY: We've validated the memory region. + unsafe { Ok(slice::from_raw_parts(ptr, len)) } + } +} + +static_assertions::const_assert_eq!(size_of::(), size_of::()); + +impl UserConstPtr { + /// Get the pointer as `&str`, validating the memory region. + pub fn get_as_str(self) -> LinuxResult> { + let slice = self.get_as_null_terminated()?; + // SAFETY: c_char is u8 + let slice = unsafe { mem::transmute::<&[c_char], &[u8]>(slice) }; + + Ok(str::from_utf8(slice).ok()) + } +} diff --git a/src/syscall_imp/fs/ctl.rs b/src/syscall_imp/fs/ctl.rs index 75391ac7..cfc8f931 100644 --- a/src/syscall_imp/fs/ctl.rs +++ b/src/syscall_imp/fs/ctl.rs @@ -245,8 +245,8 @@ pub(crate) fn sys_linkat( new_path: UserConstPtr, flags: i32, ) -> i32 { - let old_path = syscall_unwrap!(old_path.get_as_cstr()); - let new_path = syscall_unwrap!(new_path.get_as_cstr()); + let old_path = syscall_unwrap!(old_path.get_as_null_terminated()); + let new_path = syscall_unwrap!(new_path.get_as_null_terminated()); if flags != 0 { warn!("Unsupported flags: {flags}"); @@ -281,7 +281,7 @@ pub(crate) fn sys_linkat( /// flags: can be 0 or AT_REMOVEDIR /// return 0 when success, else return -1 pub fn sys_unlinkat(dir_fd: isize, path: UserConstPtr, flags: usize) -> isize { - let path = syscall_unwrap!(path.get_as_cstr()); + let path = syscall_unwrap!(path.get_as_null_terminated()); const AT_REMOVEDIR: usize = 0x200; @@ -316,7 +316,7 @@ pub(crate) fn sys_getcwd(buf: UserPtr, size: usize) -> *mut c_char { syscall_body!( sys_getcwd, Ok(arceos_posix_api::sys_getcwd( - buf.get_as_cstr()?.as_ptr() as _, + buf.get_as_null_terminated()?.as_ptr() as _, size )) ) diff --git a/src/syscall_imp/fs/io.rs b/src/syscall_imp/fs/io.rs index 58be3956..4876b1cf 100644 --- a/src/syscall_imp/fs/io.rs +++ b/src/syscall_imp/fs/io.rs @@ -28,7 +28,7 @@ pub(crate) fn sys_openat( flags: i32, modes: mode_t, ) -> isize { - let path = syscall_unwrap!(path.get_as_cstr()); + let path = syscall_unwrap!(path.get_as_null_terminated()); api::sys_openat(dirfd, path.as_ptr(), flags, modes) as _ } diff --git a/src/syscall_imp/fs/stat.rs b/src/syscall_imp/fs/stat.rs index 61490c6a..ae09e9f4 100644 --- a/src/syscall_imp/fs/stat.rs +++ b/src/syscall_imp/fs/stat.rs @@ -100,7 +100,7 @@ pub fn sys_fstatat( _flags: i32, ) -> i32 { syscall_body!(sys_fstatat, { - let path = path.get_as_cstr()?; + let path = path.get_as_null_terminated()?; let path = arceos_posix_api::handle_file_path(dir_fd, Some(path.as_ptr() as _), false)?; let kstatbuf = kstatbuf.get()?; diff --git a/src/syscall_imp/mod.rs b/src/syscall_imp/mod.rs index 4b165e81..25334851 100644 --- a/src/syscall_imp/mod.rs +++ b/src/syscall_imp/mod.rs @@ -8,7 +8,7 @@ mod utils; use core::ffi::c_char; use crate::{ - ptr::{PtrWrapper, UserConstPtr}, + ptr::UserConstPtr, task::{time_stat_from_kernel_to_user, time_stat_from_user_to_kernel}, }; use axerrno::LinuxError; @@ -62,7 +62,7 @@ macro_rules! syscall_body { } pub(crate) fn read_path_str(path: UserConstPtr) -> Result<&'static str, LinuxError> { - path.get_as_cstr()?.to_str().map_err(|_| { + path.get_as_str()?.ok_or_else(|| { warn!("Invalid path"); LinuxError::EFAULT }) diff --git a/src/syscall_imp/task/thread.rs b/src/syscall_imp/task/thread.rs index 78a7ff9d..65d69c0a 100644 --- a/src/syscall_imp/task/thread.rs +++ b/src/syscall_imp/task/thread.rs @@ -1,4 +1,7 @@ -use core::ffi::{c_char, c_int}; +use core::{ + ffi::{c_char, c_int}, + ptr, +}; use axerrno::LinuxError; use axtask::{TaskExtRef, current, yield_now}; @@ -144,9 +147,9 @@ pub(crate) fn sys_clone( pub(crate) fn sys_wait4(pid: i32, exit_code_ptr: UserPtr, option: u32) -> isize { let option_flag = WaitFlags::from_bits(option).unwrap(); syscall_body!(sys_wait4, { - let exit_code_ptr = exit_code_ptr.get()?; + let exit_code_ptr = exit_code_ptr.nullable(UserPtr::get)?; loop { - let answer = wait_pid(pid, exit_code_ptr); + let answer = wait_pid(pid, exit_code_ptr.unwrap_or_else(ptr::null_mut)); match answer { Ok(pid) => { return Ok(pid as isize);