// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use core::{cell::UnsafeCell, mem::MaybeUninit, ops::Deref};

/// A single storage slot for a `T` that does not track whether it currently holds a value.
///
/// The value lives inside an [`UnsafeCell`] wrapped in a [`MaybeUninit`], so the slot can be
/// written and read through a shared reference.
#[repr(transparent)]
pub struct Cell<T>(MaybeUninit<UnsafeCell<T>>);

impl<T> Cell<T> {
    /// Moves `value` into the slot.
    ///
    /// Whatever the slot held before is overwritten without being read or dropped.
    ///
    /// # Safety
    ///
    /// The write goes through a raw pointer obtained from a shared reference, so the caller must
    /// make sure nothing else accesses the slot at the same time.
    #[inline]
    pub unsafe fn write(&self, value: T) {
        let slot: *mut T = UnsafeCell::raw_get(self.0.as_ptr());
        slot.write(value);
    }

    /// Moves the stored value out of the slot.
    ///
    /// The value is copied out with a raw read; the slot's memory is left untouched.
    ///
    /// # Safety
    ///
    /// The slot must hold an initialized value and nothing may be writing to it at the same
    /// time. Because the memory is not cleared, the caller is responsible for treating the slot
    /// as empty afterwards.
    #[inline]
    pub unsafe fn take(&self) -> T {
        let filled: &UnsafeCell<T> = self.0.assume_init_ref();
        filled.get().read()
    }
}

/// A borrowed view over a slice of `T`.
///
/// The wrapped slice is only directly reachable from the parent module (`pub(super)`).
#[derive(Debug)]
pub struct Slice<'a, T>(pub(super) &'a [T]);

impl<'a, T> Slice<'a, Cell<T>> {
    /// Assumes that the slice of [`Cell`]s is initialized and converts it to a slice of
    /// [`UnsafeCell`]s.
    ///
    /// See [`core::mem::MaybeUninit::assume_init`]
    ///
    /// # Safety
    ///
    /// Every cell in the slice must hold an initialized value.
    #[inline]
    pub unsafe fn assume_init(self) -> Slice<'a, UnsafeCell<T>> {
        // `Cell<T>` is `#[repr(transparent)]` over `MaybeUninit<UnsafeCell<T>>`, so the slice can
        // be reinterpreted in place without copying
        let cells = self.0 as *const [Cell<T>] as *const [UnsafeCell<T>];
        Slice(&*cells)
    }

    /// Writes a value into a cell at the provided index
    ///
    /// # Safety
    ///
    /// `index` must be in bounds for the slice: the lookup is unchecked, so an out-of-bounds
    /// index is undefined behavior (builds with debug assertions panic instead).
    ///
    /// The cell at `index` must be uninitialized and the caller must have synchronized access.
    #[inline]
    pub unsafe fn write(&self, index: usize, value: T) {
        debug_assert!(
            index < self.0.len(),
            "index out of bounds: len={}, index={}",
            self.0.len(),
            index
        );
        self.0.get_unchecked(index).write(value)
    }

    /// Reads and takes the memory at a cell at the provided index
    ///
    /// # Safety
    ///
    /// `index` must be in bounds for the slice: the lookup is unchecked, so an out-of-bounds
    /// index is undefined behavior (builds with debug assertions panic instead).
    ///
    /// The cell at `index` must be initialized and the caller must have synchronized access.
    #[inline]
    pub unsafe fn take(&self, index: usize) -> T {
        debug_assert!(
            index < self.0.len(),
            "index out of bounds: len={}, index={}",
            self.0.len(),
            index
        );
        self.0.get_unchecked(index).take()
    }
}

impl<'a, T> Slice<'a, UnsafeCell<T>> {
    /// Converts the slice of [`UnsafeCell`]s into a mutable slice
    ///
    /// # Safety
    ///
    /// The slice must be exclusively owned, otherwise data races may occur.
    #[inline]
    pub unsafe fn into_mut(self) -> &'a mut [T] {
        let ptr = self.0.as_ptr() as *mut T;
        let len = self.0.len();
        core::slice::from_raw_parts_mut(ptr, len)
    }
}

impl<'a, T> Deref for Slice<'a, T> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &[T] {
        self.0
    }
}

impl<'a, T: PartialEq> PartialEq<[T]> for Slice<'a, UnsafeCell<T>> {
    #[inline]
    fn eq(&self, other: &[T]) -> bool {
        if self.len() != other.len() {
            return false;
        }

        for (a, b) in self.iter().zip(other) {
            if unsafe { &*a.get() } != b {
                return false;
            }
        }

        true
    }
}

/// Mirror of the `Slice == [T]` comparison so a plain slice can appear on the left-hand side.
impl<'a, T: PartialEq> PartialEq<Slice<'a, UnsafeCell<T>>> for [T] {
    #[inline]
    fn eq(&self, other: &Slice<'a, UnsafeCell<T>>) -> bool {
        // delegate to the comparison implemented on `Slice`, with the operands swapped
        <Slice<'a, UnsafeCell<T>> as PartialEq<[T]>>::eq(other, self)
    }
}

/// Mirror of the `Slice == [T]` comparison so a slice reference can appear on the left-hand side.
impl<'a, T: PartialEq> PartialEq<Slice<'a, UnsafeCell<T>>> for &[T] {
    #[inline]
    fn eq(&self, other: &Slice<'a, UnsafeCell<T>>) -> bool {
        // `self` is a reference to a slice reference; peel one layer and delegate to the
        // comparison implemented on `Slice`, with the operands swapped
        <Slice<'a, UnsafeCell<T>> as PartialEq<[T]>>::eq(other, *self)
    }
}

/// Two values of the same type `S`, named `head` and `tail`.
///
/// The `Pair` methods in this file treat `head` as coming before `tail` when indexing and
/// iterating.
#[derive(Debug)]
pub struct Pair<S> {
    /// The first part of the pair
    pub head: S,
    /// The second part of the pair
    pub tail: S,
}

impl<'a, T> Pair<Slice<'a, Cell<T>>> {
    /// Assumes that both halves of the pair are initialized and converts them to slices of
    /// [`UnsafeCell`]s.
    ///
    /// # Safety
    ///
    /// Every cell in both `head` and `tail` must hold an initialized value.
    #[inline]
    pub unsafe fn assume_init(self) -> Pair<Slice<'a, UnsafeCell<T>>> {
        let Pair { head, tail } = self;
        Pair {
            head: head.assume_init(),
            tail: tail.assume_init(),
        }
    }

    /// Writes a value into the cell at `index`, counting through `head` first and then `tail`
    ///
    /// Any value already in the cell is overwritten without being dropped.
    ///
    /// # Safety
    ///
    /// `index` must be less than [`Self::len`] and the caller must have synchronized access to
    /// the cell.
    #[inline]
    pub unsafe fn write(&self, index: usize, value: T) {
        let target = self.cell(index);
        target.write(value)
    }

    /// Reads and takes the value of the cell at `index`, counting through `head` first and then
    /// `tail`
    ///
    /// # Safety
    ///
    /// `index` must be less than [`Self::len`], the cell must be initialized and the caller must
    /// have synchronized access to it.
    #[inline]
    pub unsafe fn take(&self, index: usize) -> T {
        let source = self.cell(index);
        source.take()
    }

    /// Resolves `index` to a cell: indices inside `head` map to `head`, everything else is
    /// rebased onto `tail`.
    ///
    /// # Safety
    ///
    /// The `tail` lookup is unchecked, so `index` must be less than [`Self::len`].
    unsafe fn cell(&self, index: usize) -> &Cell<T> {
        match self.head.0.get(index) {
            Some(cell) => cell,
            None => {
                // the checked lookup in `head` failed, so `index` lies at or past its end
                assume!(
                    index >= self.head.0.len(),
                    "index must always be equal or greater than the `head` len"
                );
                // rebase the index so it counts from the start of `tail`
                let index = index - self.head.0.len();

                assume!(
                    self.tail.get(index).is_some(),
                    "index must be in-bounds for the `tail` slice: head={}, tail={}, index={}",
                    self.head.0.len(),
                    self.tail.0.len(),
                    index
                );
                self.tail.0.get_unchecked(index)
            }
        }
    }

    /// Iterates over every cell of `head`, followed by every cell of `tail`
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = &Cell<T>> {
        let head = self.head.0.iter();
        let tail = self.tail.0.iter();
        head.chain(tail)
    }

    /// Returns the combined number of cells in `head` and `tail`
    #[inline]
    pub fn len(&self) -> usize {
        let head_len = self.head.0.len();
        let tail_len = self.tail.0.len();
        head_len + tail_len
    }
}

impl<'a, T> Pair<Slice<'a, UnsafeCell<T>>> {
    #[inline]
    pub unsafe fn into_mut(self) -> (&'a mut [T], &'a mut [T]) {
        let head = self.head.into_mut();
        let tail = self.tail.into_mut();
        (head, tail)
    }
}