// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use core::{cell::Cell, fmt};

/// A data structure that [memoizes](https://wikipedia.org/wiki/Memoization) a query function
///
/// This is useful when the result of a query rarely changes and the query is either expensive
/// or sits on a hot code path. After the `input` is mutated, the cached value should be
/// `clear`ed to signal that the function needs to be executed again.
///
/// Whenever `get` finds a cached value, it hands that value to the configured `Check`, which may
/// re-run the query and assert that the two results match. The default `Check` is selected by
/// `DefaultConsistencyCheck`.
pub struct Memo<T: Copy, Input, Check = DefaultConsistencyCheck> {
    /// The cached query result, or `None` when the query has to be executed
    value: Cell<Option<T>>,
    /// The function being memoized
    query: fn(&Input) -> T,
    /// The checker consulted before a cached value is returned
    check: Check,
}

// `Clone` is written by hand instead of derived: `#[derive(Clone)]` would add an `Input: Clone`
// bound, even though no `Input` value is ever stored (function pointers are always `Clone`).
impl<T: Copy, Input, Check: Clone> Clone for Memo<T, Input, Check> {
    /// Returns a `Memo` with its own copy of the cached value, the same query function and a
    /// clone of the consistency check
    #[inline]
    fn clone(&self) -> Self {
        Self {
            value: Cell::new(self.value.get()),
            query: self.query,
            check: self.check.clone(),
        }
    }
}

impl<T: Copy + fmt::Debug, Input, Check> fmt::Debug for Memo<T, Input, Check> {
    /// Formats the memo as a `Memo(..)` tuple containing only the currently cached value
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cached = self.value.get();
        let mut tuple = f.debug_tuple("Memo");
        tuple.field(&cached);
        tuple.finish()
    }
}

impl<T: Copy + PartialEq + fmt::Debug, Input, Check: ConsistencyCheck> Memo<T, Input, Check> {
    /// Builds an empty `Memo` around the given `query` function
    ///
    /// Nothing is cached until the first call to `get`.
    #[inline]
    pub fn new(query: fn(&Input) -> T) -> Self {
        let value = Cell::new(None);
        let check = Check::default();
        Self {
            value,
            query,
            check,
        }
    }

    /// Returns the result of the query function for `input`, reusing the cached result if any
    ///
    /// On a cache hit, the cached value is first passed to the configured `Check`, then returned.
    /// On a cache miss, the query is executed and its result is stored for later calls.
    #[inline]
    #[track_caller]
    pub fn get(&self, input: &Input) -> T {
        match self.value.get() {
            Some(cached) => {
                // let the checker verify that the cached value is not stale
                self.check.check_consistency(cached, input, self.query);
                cached
            }
            None => {
                let computed = (self.query)(input);
                self.value.set(Some(computed));
                computed
            }
        }
    }

    /// Discards the cached value so that the next `get` executes the query again
    #[inline]
    pub fn clear(&self) {
        self.value.set(None);
    }

    /// Runs the configured consistency check against `input` when debug assertions are enabled
    ///
    /// This does nothing when debug assertions are disabled. When they are enabled it simply
    /// calls `get`, which means an empty `Memo` executes the query and caches the result.
    #[inline]
    #[track_caller]
    pub fn check_consistency(&self, input: &Input) {
        if !cfg!(debug_assertions) {
            return;
        }

        // `get` hands any cached value to the checker; the returned value is not needed here
        let _ = self.get(input);
    }
}

/// Trait to configure consistency checking behavior
///
/// `Memo::get` invokes the check each time it is about to return a cached value, which lets the
/// implementation decide whether, and how, that value is validated.
pub trait ConsistencyCheck: Clone + Default {
    /// Called when the `Memo` struct has a cached value
    ///
    /// An implementation can assert that the `cache` value matches the current `query` result
    ///
    /// * `cache` - the value currently stored in the `Memo`
    /// * `input` - the input that was passed to `Memo::get`
    /// * `query` - the memoized function; calling it with `input` recomputes the value
    fn check_consistency<T: PartialEq + fmt::Debug, Input>(
        &self,
        cache: T,
        input: &Input,
        query: fn(&Input) -> T,
    );
}

/// A [`ConsistencyCheck`] that re-runs the query on every check and panics if the cached value
/// differs from the fresh result
#[derive(Clone, Copy, Default)]
pub struct ConsistencyCheckAlways;

impl ConsistencyCheck for ConsistencyCheckAlways {
    #[inline]
    fn check_consistency<T, Input>(&self, cached: T, input: &Input, query: fn(&Input) -> T)
    where
        T: PartialEq + fmt::Debug,
    {
        // the fresh result goes on the left so that a failure reports it as `left`
        assert_eq!(query(input), cached);
    }
}

/// A [`ConsistencyCheck`] that performs no validation and never calls the query
#[derive(Clone, Copy, Default)]
pub struct ConsistencyCheckNever;

impl ConsistencyCheck for ConsistencyCheckNever {
    #[inline]
    fn check_consistency<T, Input>(&self, _: T, _: &Input, _: fn(&Input) -> T)
    where
        T: PartialEq + fmt::Debug,
    {
        // intentionally empty: the cached value is trusted as-is
    }
}

/// The consistency check used by `Memo` when none is specified
///
/// With debug assertions enabled this is [`ConsistencyCheckAlways`].
#[cfg(debug_assertions)]
pub type DefaultConsistencyCheck = ConsistencyCheckAlways;
/// The consistency check used by `Memo` when none is specified
///
/// With debug assertions disabled this is [`ConsistencyCheckNever`].
#[cfg(not(debug_assertions))]
pub type DefaultConsistencyCheck = ConsistencyCheckNever;

#[cfg(test)]
mod tests {
    use super::*;

    /// Query input that carries the value to return and whether the query is allowed to run
    #[derive(Debug, Default)]
    struct Input<Value> {
        value: Value,
        should_query: bool,
    }

    /// Builds an input for a lookup that is expected to execute the query
    fn expect_query(value: u64) -> Input<u64> {
        Input {
            value,
            should_query: true,
        }
    }

    /// Builds an input for a lookup that is expected to be served from the cache
    fn expect_cached(value: u64) -> Input<u64> {
        Input {
            value,
            should_query: false,
        }
    }

    #[test]
    fn memo_test() {
        // the query panics if it is executed for an input that should have hit the cache
        let memo = Memo::<u64, Input<_>, ConsistencyCheckNever>::new(|input| {
            assert!(
                input.should_query,
                "query was called when it wasn't expected"
            );
            input.value
        });

        // the first lookup has to run the query
        let first = memo.get(&expect_query(1));
        assert_eq!(first, 1);

        // the second lookup returns the cached value, ignoring the new input
        let second = memo.get(&expect_cached(2));
        assert_eq!(second, 1);

        memo.clear();

        // after clearing, the query runs again
        let third = memo.get(&expect_query(3));
        assert_eq!(third, 3);

        // and its result is cached once more
        let fourth = memo.get(&expect_cached(4));
        assert_eq!(fourth, 3);
    }
}