diff --git a/src/context/list.rs b/src/context/list.rs index d56b0e2..a1a2895 100644 --- a/src/context/list.rs +++ b/src/context/list.rs @@ -47,6 +47,9 @@ impl ContextList { pub fn iter(&self) -> ::alloc::collections::btree_map::Iter>> { self.map.iter() } + pub fn range(&self, range: impl core::ops::RangeBounds) -> ::alloc::collections::btree_map::Range<'_, ContextId, Arc>> { + self.map.range(range) + } /// Create a new context. pub fn new_context(&mut self) -> Result<&Arc>> { diff --git a/src/context/switch.rs b/src/context/switch.rs index 877033d..75d97b6 100644 --- a/src/context/switch.rs +++ b/src/context/switch.rs @@ -1,5 +1,8 @@ +use core::ops::Bound; use core::sync::atomic::Ordering; +use alloc::sync::Arc; + use crate::context::signal::signal_handler; use crate::context::{arch, contexts, Context, Status, CONTEXT_ID}; use crate::gdt; @@ -72,8 +75,6 @@ unsafe fn runnable(context: &Context, cpu_id: usize) -> bool { /// /// Do not call this while holding locks! pub unsafe fn switch() -> bool { - use core::ops::DerefMut; - //set PIT Interrupt counter to 0, giving each process same amount of PIT ticks let ticks = PIT_TICKS.swap(0, Ordering::SeqCst); @@ -84,93 +85,108 @@ pub unsafe fn switch() -> bool { let cpu_id = crate::cpu_id(); - let from_ptr; - let mut to_ptr = 0 as *mut Context; + let from_context_lock; + let mut from_context_guard; + let mut to_context_lock: Option<(Arc>, *mut Context)> = None; let mut to_sig = None; { let contexts = contexts(); { - let context_lock = contexts + from_context_lock = Arc::clone(contexts .current() - .expect("context::switch: not inside of context"); - let mut context = context_lock.write(); - context.ticks += ticks as u64 + 1; // Always round ticks up - from_ptr = context.deref_mut() as *mut Context; - } - - macro_rules! to { - ($context:expr) => {{ - let context: &mut Context = $context; - if runnable(context, cpu_id) { - to_ptr = context as *mut Context; - if context.ksig.is_none() { - to_sig = context.pending.pop_front(); - } - true - } else { - false - } - }}; - }; - - for (_pid, context_lock) in contexts.iter() { - let mut context = context_lock.write(); - update(&mut context, cpu_id); + .expect("context::switch: not inside of context")); + from_context_guard = from_context_lock.write(); + from_context_guard.ticks += ticks as u64 + 1; // Always round ticks up } for (pid, context_lock) in contexts.iter() { - if *pid > (*from_ptr).id { - let mut context = context_lock.write(); - if to!(&mut context) { - break; - } - } + let mut context; + let context_ref = if *pid == from_context_guard.id { + &mut *from_context_guard + } else { + context = context_lock.write(); + &mut *context + }; + update(context_ref, cpu_id); } - if to_ptr as usize == 0 { - for (pid, context_lock) in contexts.iter() { - if *pid < (*from_ptr).id { - let mut context = context_lock.write(); - if to!(&mut context) { - break; - } + for (_pid, context_lock) in contexts + // Include all contexts with IDs greater than the current... + .range( + (Bound::Excluded(from_context_guard.id), Bound::Unbounded) + ) + .chain(contexts + // ... and all contexts with IDs less than the current... + .range((Bound::Unbounded, Bound::Excluded(from_context_guard.id))) + ) + // ... but not the current context, which is already locked + { + let context_lock = Arc::clone(context_lock); + let mut to_context_guard = context_lock.write(); + + if runnable(&*to_context_guard, cpu_id) { + if to_context_guard.ksig.is_none() { + to_sig = to_context_guard.pending.pop_front(); } + let ptr: *mut Context = &mut *to_context_guard; + core::mem::forget(to_context_guard); + to_context_lock = Some((context_lock, ptr)); + break; + } else { + continue; } } }; // Switch process states, TSS stack pointer, and store new context ID - if to_ptr as usize != 0 { - (*from_ptr).running = false; - (*to_ptr).running = true; - if let Some(ref stack) = (*to_ptr).kstack { + if let Some((to_context_lock, to_ptr)) = to_context_lock { + let to_context: &mut Context = &mut *to_ptr; + + from_context_guard.running = false; + to_context.running = true; + if let Some(ref stack) = to_context.kstack { gdt::set_tss_stack(stack.as_ptr() as usize + stack.len()); } - gdt::set_tcb((*to_ptr).id.into()); - CONTEXT_ID.store((*to_ptr).id, Ordering::SeqCst); - } + gdt::set_tcb(to_context.id.into()); + CONTEXT_ID.store(to_context.id, Ordering::SeqCst); - if to_ptr as usize == 0 { - // No target was found, unset global lock and return - arch::CONTEXT_SWITCH_LOCK.store(false, Ordering::SeqCst); - - false - } else { if let Some(sig) = to_sig { // Signal was found, run signal handler //TODO: Allow nested signals - assert!((*to_ptr).ksig.is_none()); + assert!(to_context.ksig.is_none()); - let arch = (*to_ptr).arch.clone(); - let kfx = (*to_ptr).kfx.clone(); - let kstack = (*to_ptr).kstack.clone(); - (*to_ptr).ksig = Some((arch, kfx, kstack, sig)); - (*to_ptr).arch.signal_stack(signal_handler, sig); + let arch = to_context.arch.clone(); + let kfx = to_context.kfx.clone(); + let kstack = to_context.kstack.clone(); + to_context.ksig = Some((arch, kfx, kstack, sig)); + to_context.arch.signal_stack(signal_handler, sig); } + let from_ptr: *mut Context = &mut *from_context_guard; + let to_ptr: *mut Context = &mut *to_context; + + // FIXME: Ensure that this critical section is somehow still protected by the lock, and not + // just for other processes' context switching, but for other operations which do not + // require CONTEXT_SWITCH_LOCK. + // + // What I suggest, is that we wrap the Context struct (typically stored as + // `Arc>`), into a wrapper with interior locking for the inner context type + // (which could be something like `RwLock`). The wrapper would also contain + // a field of type `UnsafeCell`, which would be accessible if and only if the + // `running` field is set to false, making that field function as a lock. + drop(from_context_guard); + drop(from_context_lock); + to_context_lock.force_write_unlock(); + drop(to_context_lock); + (*from_ptr).arch.switch_to(&mut (*to_ptr).arch); true + } else { + // No target was found, unset global lock and return + arch::CONTEXT_SWITCH_LOCK.store(false, Ordering::SeqCst); + + false } } diff --git a/src/syscall/process.rs b/src/syscall/process.rs index dfed1b1..b10e00b 100644 --- a/src/syscall/process.rs +++ b/src/syscall/process.rs @@ -46,7 +46,7 @@ pub fn clone(flags: CloneFlags, stack_base: usize) -> Result { let ens; let umask; let sigmask; - let cpu_id_opt = None; + let mut cpu_id_opt = None; let arch; let vfork; let mut kfx_opt = None;