use std::any::Any;
use std::mem::MaybeUninit;
use std::ptr;
use crate::reg_context::RegContext;
thread_local! {
    /// Per-thread root of the context stack.
    ///
    /// The root is boxed so its address stays stable, and it starts out as its
    /// own parent: the root's `parent` field is the "top of stack" pointer, and
    /// with nothing pushed the top is the root itself.
    static ROOT_CONTEXT: Box<Context> = {
        let mut ctx = Box::new(Context::new());
        let self_ptr: *mut Context = &mut *ctx;
        ctx.parent = self_ptr;
        ctx
    };
}
/// Nightly-only cached address of this thread's `ROOT_CONTEXT` box.
///
/// Starts null and is filled in lazily by `init_root_p` the first time
/// `ContextStack::current` runs on a thread, so later calls skip the
/// `thread_local!` key lookup.
#[cfg(nightly)]
#[thread_local]
static mut ROOT_CONTEXT_P: *mut Context = ptr::null_mut();
/// Errors used by the generator runtime, e.g. as `std::panic::panic_any` payloads.
// NOTE(review): `dead_code` is allowed presumably because not every variant is
// constructed in every build configuration; confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Error {
    /// NOTE(review): presumably signals a finished generator; not constructed in this file.
    Done,
    /// NOTE(review): presumably signals a cancelled generator; not constructed in this file.
    Cancel,
    /// Panic payload raised by `type_error` when a type-erased parameter or
    /// return slot does not hold the expected `Option<T>`.
    TypeErr,
    /// Panic payload that `test_overflow` expects after the stack guard range is touched.
    StackErr,
    /// NOTE(review): semantics not visible in this file.
    ContextErr,
}
/// Execution context for a thread's root or for a generator.
///
/// Contexts are linked into a per-thread stack through `parent`/`child`;
/// the links are maintained by `ContextStack::push_context` and
/// `ContextStack::pop_context`.
// `repr(C)` makes the in-memory field order equal to declaration order, so do
// not reorder fields. NOTE(review): the reason for `align(128)` is not visible
// here (cache-line padding? offsets assumed elsewhere?) — confirm before changing.
#[repr(C)]
#[repr(align(128))]
pub struct Context {
    /// Register state for this context (`RegContext::empty()` after `new`).
    pub regs: RegContext,
    /// Context pushed directly on top of this one, or null. Set by
    /// `push_context`, cleared by `pop_context`.
    child: *mut Context,
    /// For the root: the current top of the stack (the root itself when
    /// nothing is pushed). For a pushed context: the context below it. After
    /// `pop_context`, it holds the top that was current at pop time, which a
    /// later `push_context` restores.
    pub parent: *mut Context,
    /// Type-erased pointer to an `Option<A>` slot, read by `get_para` /
    /// `co_get_para` and written by `co_set_para`. Zeroed by `new`.
    pub para: MaybeUninit<*mut dyn Any>,
    /// Type-erased pointer to an `Option<T>` slot, written by `set_ret` /
    /// `co_set_ret`. Zeroed by `new`.
    pub ret: MaybeUninit<*mut dyn Any>,
    /// Initialized to 1 by `new`. NOTE(review): meaning not visible in this
    /// file — presumably a running/reference flag; confirm.
    pub _ref: usize,
    /// Opaque per-context data pointer. `get_local_data` / `ContextStack::co_ctx`
    /// report the innermost context where this is non-null.
    pub local_data: *mut u8,
    /// `None` after `new`. NOTE(review): presumably stores a captured panic
    /// payload; not written anywhere in this file.
    pub err: Option<Box<dyn Any + Send>>,
    /// Guard bounds `(low, high)`. `guard::current` returns
    /// `low - page_size()..high` for the context directly above the root.
    pub stack_guard: (usize, usize),
}
impl Context {
pub fn new() -> Context {
Context {
regs: RegContext::empty(),
para: MaybeUninit::zeroed(),
ret: MaybeUninit::zeroed(),
_ref: 1, err: None,
child: ptr::null_mut(),
parent: ptr::null_mut(),
local_data: ptr::null_mut(),
stack_guard: (0, 0),
}
}
#[inline]
pub fn is_generator(&self) -> bool {
self.parent != self as *const _ as *mut _
}
#[inline]
pub fn get_para<A>(&mut self) -> Option<A>
where
A: Any,
{
let para = unsafe {
let para_ptr = *self.para.as_mut_ptr();
assert!(!para_ptr.is_null());
&mut *para_ptr
};
match para.downcast_mut::<Option<A>>() {
Some(v) => v.take(),
None => type_error::<A>("get yield type mismatch error detected"),
}
}
#[inline]
pub fn co_get_para<A>(&mut self) -> Option<A> {
let para = unsafe {
let para_ptr = *self.para.as_mut_ptr();
debug_assert!(!para_ptr.is_null());
&mut *(para_ptr as *mut Option<A>)
};
para.take()
}
#[inline]
pub fn co_set_para<A>(&mut self, data: A) {
let para = unsafe {
let para_ptr = *self.para.as_mut_ptr();
debug_assert!(!para_ptr.is_null());
&mut *(para_ptr as *mut Option<A>)
};
*para = Some(data);
}
#[inline]
pub fn set_ret<T>(&mut self, v: T)
where
T: Any,
{
let ret = unsafe {
let ret_ptr = *self.ret.as_mut_ptr();
assert!(!ret_ptr.is_null());
&mut *ret_ptr
};
match ret.downcast_mut::<Option<T>>() {
Some(r) => *r = Some(v),
None => type_error::<T>("yield type mismatch error detected"),
}
}
#[inline]
pub fn co_set_ret<T>(&mut self, v: T) {
let ret = unsafe {
let ret_ptr = *self.ret.as_mut_ptr();
debug_assert!(!ret_ptr.is_null());
&mut *(ret_ptr as *mut Option<T>)
};
*ret = Some(v);
}
}
/// Handle to the calling thread's context stack.
///
/// Obtained via `ContextStack::current()`. `root` points at the thread's root
/// `Context`, whose `parent` field tracks the current top of the stack.
pub struct ContextStack {
    pub(crate) root: *mut Context,
}
/// Caches the address of this thread's boxed `ROOT_CONTEXT` in `ROOT_CONTEXT_P`.
///
/// # Safety
///
/// Writes the `#[thread_local] static mut ROOT_CONTEXT_P`; the caller must not
/// hold any reference into that static while this runs.
#[cfg(nightly)]
#[inline(never)]
unsafe fn init_root_p() {
    let root: *mut Context =
        ROOT_CONTEXT.with(|boxed| &**boxed as *const Context as *mut Context);
    ROOT_CONTEXT_P = root;
}
impl ContextStack {
    /// Returns a handle to this thread's context stack (nightly builds).
    ///
    /// On first use per thread the root address is cached in
    /// `ROOT_CONTEXT_P` by `init_root_p`; later calls read the cache directly.
    #[cfg(nightly)]
    #[inline(never)]
    pub fn current() -> ContextStack {
        unsafe {
            if ROOT_CONTEXT_P.is_null() {
                init_root_p();
            }
            ContextStack {
                root: ROOT_CONTEXT_P,
            }
        }
    }

    /// Returns a handle to this thread's context stack (non-nightly builds).
    ///
    /// Reads the address of the boxed `ROOT_CONTEXT` on every call.
    #[cfg(not(nightly))]
    #[inline(never)]
    pub fn current() -> ContextStack {
        let root = ROOT_CONTEXT.with(|r| &**r as *const _ as *mut Context);
        ContextStack { root }
    }

    /// Returns the context at the top of the stack.
    ///
    /// The root's `parent` field is the top pointer; with nothing pushed it
    /// points back at the root itself.
    // NOTE(review): produces a `&'static mut` from a raw pointer; callers must
    // not keep two overlapping results alive.
    #[inline]
    pub fn top(&self) -> &'static mut Context {
        let root = unsafe { &mut *self.root };
        unsafe { &mut *root.parent }
    }

    /// Walks from the top of the stack towards the root via `parent` links and
    /// returns the first context whose `local_data` is non-null.
    ///
    /// The root itself is never returned; `None` if no pushed context has
    /// `local_data` set.
    #[inline]
    pub fn co_ctx(&self) -> Option<&'static mut Context> {
        let root = unsafe { &mut *self.root };
        let mut ctx = unsafe { &mut *root.parent };
        while ctx as *const _ != root as *const _ {
            if !ctx.local_data.is_null() {
                return Some(ctx);
            }
            ctx = unsafe { &mut *ctx.parent };
        }
        None
    }

    /// Pushes `ctx` on top of the current top context.
    ///
    /// Afterwards the old top's `child` is `ctx`, `ctx.parent` is the old top,
    /// and the stack's new top is whatever `ctx.parent` held before the call.
    #[inline]
    pub fn push_context(&self, ctx: *mut Context) {
        let root = unsafe { &mut *self.root };
        let ctx = unsafe { &mut *ctx };
        let top = unsafe { &mut *root.parent };
        // Read before it is overwritten below. After a `pop_context`, this is
        // the top that was current when `ctx` was popped. For a fresh context it
        // is whatever was stored outside this file (presumably `ctx` itself; verify).
        let new_top = ctx.parent;
        // Link the old top and `ctx`.
        top.child = ctx;
        ctx.parent = top;
        // Publish the new top.
        root.parent = new_top;
    }

    /// Pops `ctx` off the stack and returns the context below it, which
    /// becomes the new top.
    #[inline]
    pub fn pop_context(&self, ctx: *mut Context) -> &'static mut Context {
        let root = unsafe { &mut *self.root };
        let ctx = unsafe { &mut *ctx };
        // The context `ctx` was pushed onto.
        let parent = unsafe { &mut *ctx.parent };
        // Save the current top inside `ctx`, so a later `push_context(ctx)`
        // restores it.
        ctx.parent = root.parent;
        // Unlink `ctx` from its parent.
        parent.child = ptr::null_mut();
        // The parent becomes the new top.
        root.parent = parent;
        parent
    }
}
/// Logs a type-mismatch diagnostic via `error!` and panics with
/// `Error::TypeErr` as the payload (via `std::panic::panic_any`).
///
/// `msg` describes the failed operation; the name of the expected type `A` is
/// appended to the log line. Never returns.
///
/// Marked `#[cold]` / `#[inline(never)]` because it is reached only on a
/// mismatch. This keeps the logging and unwinding code out of the inlined fast
/// paths of `get_para` and `set_ret`, instead of inviting it in with `#[inline]`.
#[cold]
#[inline(never)]
fn type_error<A>(msg: &str) -> ! {
    error!("{msg}, expected type: {}", std::any::type_name::<A>());
    std::panic::panic_any(Error::TypeErr)
}
/// Returns `true` when the calling thread's root context has a child.
///
/// The root's `child` is set when a context is pushed directly on top of the
/// root (`ContextStack::push_context`) and cleared again when that context is
/// popped.
#[inline]
pub fn is_generator() -> bool {
    let stack = ContextStack::current();
    // SAFETY: `current()` yields a pointer to this thread's live root context.
    let child = unsafe { (*stack.root).child };
    !child.is_null()
}
/// Returns the `local_data` pointer of the innermost context with non-null
/// `local_data`, or a null pointer if there is none.
///
/// The search walks from the top of this thread's context stack towards the
/// root, excluding the root. This is exactly the walk performed by
/// `ContextStack::co_ctx`; delegating to it keeps the two from drifting apart.
#[inline]
pub fn get_local_data() -> *mut u8 {
    ContextStack::current()
        .co_ctx()
        .map_or(ptr::null_mut(), |ctx| ctx.local_data)
}
/// Stack-guard query for the generator running on the current thread.
pub mod guard {
    use crate::is_generator;
    use crate::rt::ContextStack;
    use crate::stack::sys::page_size;
    use std::ops::Range;

    /// Half-open address range `start..end`.
    pub type Guard = Range<usize>;

    /// Returns the guard range of the context directly above the thread's
    /// root, extended one page below its recorded lower bound.
    ///
    /// # Panics
    /// Panics if called outside a generator (`is_generator()` is false).
    /// The subtraction also panics in debug builds if the lower bound is
    /// smaller than one page.
    pub fn current() -> Guard {
        assert!(is_generator());
        let root = ContextStack::current().root;
        // SAFETY: `is_generator()` just confirmed the root's `child` is non-null.
        let (low, high) = unsafe { (*(*root).child).stack_guard };
        let start = low - page_size();
        start..high
    }
}
#[cfg(test)]
mod test {
    use super::is_generator;

    /// A plain test thread is not running inside a generator.
    #[test]
    fn test_is_context() {
        assert!(!is_generator());
    }

    /// Reading the start of the guard range from inside a generator must
    /// surface as a panic whose payload is `Error::StackErr`.
    ///
    /// The conversion of that fault into a panic happens outside this file.
    /// The loop runs twice, presumably to check the mechanism still works after
    /// a first overflow — NOTE(review) confirm.
    #[test]
    fn test_overflow() {
        use crate::*;
        use std::panic::catch_unwind;
        for _ in 0..2 {
            let result = catch_unwind(|| {
                let mut g = Gn::new_scoped(move |_s: Scope<(), ()>| {
                    let guard = super::guard::current();
                    // Touch the first word of the guard range; expected to fault.
                    std::hint::black_box(unsafe { *(guard.start as *const usize) });
                    // Only reached if the read above did not fault.
                    eprintln!("entered unreachable code");
                    std::process::abort();
                });
                g.next();
            });
            assert!(matches!(
                result.map_err(|err| *err.downcast::<Error>().unwrap()),
                Err(Error::StackErr)
            ));
        }
    }
}