/src/capabilities.rs
use core::cell::RefCell;
use core::fmt;

use critical_section::Mutex;

use crate::task::TaskId;

/// Maximum number of capability tokens a single `CapRegistry::take` call can
/// return (it is the capacity of the vector `take` hands back).
pub const MAX_CAPS: usize = 4;

/// The global capability registry.
///
/// Holds `None` until `init_capabilities` installs a registry; all later access
/// goes through `with_cap_registry`, which borrows it inside a critical section.
static CAP_REGISTRY: Mutex<RefCell<Option<CapRegistry>>> = Mutex::new(RefCell::new(None));

/// The kinds of capability a task can hold.
///
/// The explicit discriminants are used as bit positions by `CapSet`, so they
/// must stay small and unique.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub enum CapType {
    ConsoleRead = 0,
    ConsoleWrite = 1,
    Led = 2,
    Keyboard = 3,
}

impl fmt::Display for CapType {
    /// Writes a short human-readable label for the capability (no padding is applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::ConsoleRead => "ConRead",
            Self::ConsoleWrite => "ConWrite",
            Self::Led => "LED",
            Self::Keyboard => "Keyboard",
        };
        f.write_str(label)
    }
}

/// A compact set of [`CapType`]s stored as a bit mask.
///
/// Each capability occupies the bit whose index equals its discriminant.
pub struct CapSet(u32);

impl CapSet {
    /// Bit representing `cap` within the mask.
    fn bit(cap: CapType) -> u32 {
        1 << (cap as u32)
    }

    /// Combined mask of `caps`, or `None` if any capability appears more than once.
    fn mask_of(caps: &[CapType]) -> Option<u32> {
        caps.iter().try_fold(0u32, |mask, &cap| {
            let bit = Self::bit(cap);
            if mask & bit != 0 {
                None
            } else {
                Some(mask | bit)
            }
        })
    }

    /// Builds a set containing every capability in `caps` (duplicates are harmless).
    pub fn new(caps: &[CapType]) -> CapSet {
        CapSet(caps.iter().fold(0, |mask, &cap| mask | Self::bit(cap)))
    }

    /// Returns `true` if `cap` is in the set.
    pub fn has(&self, cap: CapType) -> bool {
        self.0 & Self::bit(cap) != 0
    }

    /// Adds every capability in `caps`, returning `self` for chaining.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` and leaves the set unchanged if any capability is
    /// already present or is listed more than once in `caps`.
    pub fn add(&mut self, caps: &[CapType]) -> Result<&mut CapSet, ()> {
        // Validate everything before mutating so a failed call is all-or-nothing.
        let mask = Self::mask_of(caps).ok_or(())?;
        if self.0 & mask != 0 {
            return Err(());
        }
        self.0 |= mask;
        Ok(self)
    }

    /// Removes every capability in `caps`, returning `self` for chaining.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` and leaves the set unchanged if any capability is
    /// missing from the set or is listed more than once in `caps`.
    pub fn remove(&mut self, caps: &[CapType]) -> Result<&mut CapSet, ()> {
        // Validate everything before mutating so a failed call is all-or-nothing.
        let mask = Self::mask_of(caps).ok_or(())?;
        if self.0 & mask != mask {
            return Err(());
        }
        self.0 &= !mask;
        Ok(self)
    }
}

/// Proof that the holder owns one capability.
///
/// The tuple field is private and the type is not `Clone`, so code outside
/// this module can neither construct nor duplicate a token.
#[derive(Debug)]
pub struct CapToken(CapType);

impl CapToken {
    /// Returns which capability this token grants.
    pub fn captype(&self) -> CapType {
        let &CapToken(kind) = self;
        kind
    }
}

/// Ownership state of a single capability.
#[derive(Debug)]
enum CapState {
    /// The registry holds the token and it can be taken.
    Available(CapToken),
    /// The token has been handed to the given task.
    Taken(TaskId),
}

impl CapState {
    /// `true` while the registry still holds the token.
    fn available(&self) -> bool {
        matches!(self, CapState::Available(_))
    }

    /// Hands the token to `tid` and records `tid` as its owner.
    ///
    /// Returns `None` if the token is already taken; the existing owner is kept.
    fn take(&mut self, tid: TaskId) -> Option<CapToken> {
        match core::mem::replace(self, CapState::Taken(tid)) {
            CapState::Available(token) => Some(token),
            already_taken @ CapState::Taken(_) => {
                // Undo the speculative swap so the original owner is preserved.
                *self = already_taken;
                None
            }
        }
    }

    /// Puts `token` back if `tid` is the recorded owner and reports whether it did.
    ///
    /// On `false` the state is unchanged and `token` is dropped.
    fn replace(&mut self, token: CapToken, tid: TaskId) -> bool {
        match self {
            CapState::Taken(owner) if tid == *owner => {
                *self = CapState::Available(token);
                true
            }
            CapState::Taken(_) | CapState::Available(_) => false,
        }
    }
}

/// One registry slot: a capability kind paired with its current ownership state.
#[derive(Debug)]
struct CapEntry {
    cap: CapType,
    state: CapState,
}

impl CapEntry {
    /// Creates the slot for `cap`, with its token initially available.
    fn new(cap: CapType) -> CapEntry {
        let token = CapToken(cap);
        CapEntry {
            state: CapState::Available(token),
            cap,
        }
    }

    /// Lends this slot's token to `tid`; `None` if it is already taken.
    fn take(&mut self, tid: TaskId) -> Option<CapToken> {
        CapState::take(&mut self.state, tid)
    }

    /// Returns `token` to this slot if `tid` currently owns it, reporting success.
    fn give(&mut self, token: CapToken, tid: TaskId) -> bool {
        CapState::replace(&mut self.state, token, tid)
    }
}

#[derive(Debug)]
pub struct CapRegistry {
    capabilities: heapless::Vec<CapEntry, 10>,
}

impl CapRegistry {
    pub fn new() -> CapRegistry {
        let mut capabilities = heapless::Vec::new();
        capabilities
            .push(CapEntry::new(CapType::ConsoleRead))
            .unwrap();
        capabilities
            .push(CapEntry::new(CapType::ConsoleWrite))
            .unwrap();
        capabilities.push(CapEntry::new(CapType::Led)).unwrap();
        capabilities.push(CapEntry::new(CapType::Keyboard)).unwrap();
        CapRegistry { capabilities }
    }

    pub fn available(&self, cap: CapType) -> bool {
        let cap_entry = self.capabilities.iter().find(|ce| ce.cap == cap && ce.state.available());
        cap_entry.is_some()
    }

    pub fn take(&mut self, cap: &[CapType], tid: TaskId) -> Option<heapless::Vec<CapToken, MAX_CAPS>> {
        // first, check if we have the capabilities available
        let available = cap.iter().all(|cap| self.available(*cap));
        if !available {
            return None;
        }

        // Then take all the tokens
        cap.iter().map(|c| {
            let cap_entry = self.capabilities.iter_mut().find(|ce| ce.cap == *c).unwrap();
            cap_entry.take(tid)
        }).collect()
    }

    pub fn give(&mut self, token: CapToken, tid: TaskId) -> bool {
        let cap_entry = self.capabilities.iter_mut().find(|ce| ce.cap == token.0);
        if let Some(cap_entry) = cap_entry {
            cap_entry.give(token, tid)
        } else {
            false
        }
    }
}

/// Installs a freshly constructed [`CapRegistry`] as the global registry.
///
/// Any previously installed registry is replaced and dropped. Calling this
/// again therefore discards all recorded ownership and makes every capability
/// available, whether or not tokens are still outstanding.
pub fn init_capabilities() {
    critical_section::with(|cs| {
        CAP_REGISTRY
            .borrow_ref_mut(cs)
            .replace(CapRegistry::new());
    });
}

/// Runs `f` with exclusive, mutable access to the global capability registry
/// and returns its result.
///
/// `f` runs inside `critical_section::with`, while the registry's `RefCell` is
/// mutably borrowed. `f` is called exactly once, so any `FnOnce` closure is
/// accepted.
///
/// # Panics
///
/// Panics if [`init_capabilities`] has not been called yet. Also panics if
/// called again from inside `f`, because the `RefCell` is already borrowed.
pub fn with_cap_registry<F, R>(f: F) -> R
where
    F: FnOnce(&mut CapRegistry) -> R,
{
    critical_section::with(|cs| {
        let mut guard = CAP_REGISTRY.borrow_ref_mut(cs);
        let registry = guard
            .as_mut()
            .expect("Capabilities registry not initialized");
        f(registry)
    })
}