// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{
    contexts::OnTransmitError,
    sync::{IncrementalValueSync, ValueToFrameWriter},
    transmission,
    transmission::WriteContext,
};
use s2n_quic_core::{
    ack, frame::MaxStreams, packet::number::PacketNumber, stream::StreamId, transport,
    varint::VarInt,
};

//= https://www.rfc-editor.org/rfc/rfc9000#section-4.6
//# An endpoint MUST NOT wait
//# to receive this signal before advertising additional credit, since
//# doing so will mean that the peer will be blocked for at least an
//# entire round trip
/// Divisor applied to the local stream limit when the `MAX_STREAMS`
/// synchronizer is constructed.
///
/// `RemoteInitiated::new` passes `max_local_limit / MAX_STREAMS_SYNC_FRACTION`
/// to `IncrementalValueSync::new`, with the intent that a MAX_STREAMS frame is
/// sent whenever 1/10th of the window has been closed.
pub const MAX_STREAMS_SYNC_FRACTION: VarInt = VarInt::from_u8(10);
//= https://www.rfc-editor.org/rfc/rfc9000#section-19.11
//# Maximum Streams:  A count of the cumulative number of streams of the
//# corresponding type that can be opened over the lifetime of the
//# connection.  This value cannot exceed 2^60, as it is not possible
//# to encode stream IDs larger than 2^62-1.
/// Upper bound (2^60) for the stream limit advertised to the peer.
///
/// `RemoteInitiated::on_close_stream` clamps the recomputed limit to this
/// value before handing it to the synchronizer.
// SAFETY: 2^60 is less than MAX_VARINT_VALUE
const MAX_STREAMS_MAX_VALUE: VarInt = unsafe { VarInt::new_unchecked(1 << 60) };

/// The RemoteInitiated controller controls streams initiated by the peer
///
/// It keeps running totals of opened and closed streams and, each time a
/// stream is closed, recomputes the stream limit that should be advertised
/// to the peer.
#[derive(Debug)]
pub(super) struct RemoteInitiated {
    /// The max stream limit specified by the local endpoint.
    ///
    /// Used to calculate updated max_streams_sync value as the peer
    /// closes streams.
    max_local_limit: VarInt,
    /// Responsible for advertising updated max stream frames as the
    /// peer closes streams
    max_streams_sync: IncrementalValueSync<VarInt, MaxStreamsToFrameWriter>,
    /// Running total of streams opened; incremented by `on_open_stream`.
    opened_streams: VarInt,
    /// Running total of streams closed; incremented by `on_close_stream`.
    ///
    /// Added to `max_local_limit` to derive the next limit to advertise.
    closed_streams: VarInt,
}

impl RemoteInitiated {
    /// Creates a controller with no streams opened or closed yet.
    ///
    /// `max_local_limit` is stored as the local limit and also seeds the
    /// `MAX_STREAMS` synchronizer. The synchronizer additionally receives
    /// `max_local_limit / MAX_STREAMS_SYNC_FRACTION` as its third argument.
    pub fn new(max_local_limit: VarInt) -> Self {
        Self {
            max_local_limit,
            max_streams_sync: IncrementalValueSync::new(
                max_local_limit,
                max_local_limit,
                max_local_limit / MAX_STREAMS_SYNC_FRACTION,
            ),
            opened_streams: VarInt::from_u8(0),
            closed_streams: VarInt::from_u8(0),
        }
    }

    /// Validates that `stream_id` is within the stream limit currently
    /// tracked by the synchronizer.
    ///
    /// # Errors
    ///
    /// Returns `transport::Error::STREAM_LIMIT_ERROR` when `stream_id` is at
    /// or beyond the first stream ID that the current limit does not allow.
    ///
    /// # Panics
    ///
    /// Panics if `StreamId::nth` cannot produce a stream ID for the current
    /// limit.
    pub fn on_remote_open_stream(&mut self, stream_id: StreamId) -> Result<(), transport::Error> {
        // get the total number of streams that are allowed
        let max_allowed_stream_limit = self.max_streams_sync.latest_value().as_u64();

        // since streams are 0-indexed, using `max_allowed_stream_limit` to calculate
        // the stream_id gives 1 stream_id greater than the allowed limit
        let not_allowed_stream_id = StreamId::nth(
            stream_id.initiator(),
            stream_id.stream_type(),
            max_allowed_stream_limit,
        )
        .expect("max_streams is limited to MAX_STREAMS_MAX_VALUE");

        if stream_id >= not_allowed_stream_id {
            //= https://www.rfc-editor.org/rfc/rfc9000#section-4.6
            //# Endpoints MUST NOT exceed the limit set by their peer.  An endpoint
            //# that receives a frame with a stream ID exceeding the limit it has
            //# sent MUST treat this as a connection error of type
            //# STREAM_LIMIT_ERROR; see Section 11 for details on error handling.

            //= https://www.rfc-editor.org/rfc/rfc9000#section-19.11
            //# An endpoint MUST terminate a connection
            //# with an error of type STREAM_LIMIT_ERROR if a peer opens more streams
            //# than was permitted.
            return Err(transport::Error::STREAM_LIMIT_ERROR);
        }
        Ok(())
    }

    /// Records that one more stream has been opened.
    #[inline]
    pub fn on_open_stream(&mut self) {
        self.opened_streams += 1;

        self.check_integrity();
    }

    /// Records that one more stream has been closed and raises the limit
    /// tracked by the synchronizer accordingly.
    ///
    /// The new limit is `closed_streams + max_local_limit`, computed with
    /// saturating addition and capped at `MAX_STREAMS_MAX_VALUE`.
    pub fn on_close_stream(&mut self) {
        self.closed_streams += 1;

        let max_streams = self
            .closed_streams
            .saturating_add(self.max_local_limit)
            .min(MAX_STREAMS_MAX_VALUE);
        self.max_streams_sync.update_latest_value(max_streams);

        self.check_integrity();
    }

    /// Returns the number of streams currently open
    /// (streams opened minus streams closed).
    #[inline]
    pub fn open_stream_count(&self) -> VarInt {
        self.opened_streams - self.closed_streams
    }

    /// Returns the running total of opened streams, including those that
    /// have since been closed.
    #[inline]
    pub fn total_open_stream_count(&self) -> VarInt {
        self.opened_streams
    }

    /// Forwards a packet acknowledgement to the `MAX_STREAMS` synchronizer.
    #[inline]
    pub fn on_packet_ack<A: ack::Set>(&mut self, ack_set: &A) {
        self.max_streams_sync.on_packet_ack(ack_set)
    }

    /// Forwards a packet loss notification to the `MAX_STREAMS` synchronizer.
    #[inline]
    pub fn on_packet_loss<A: ack::Set>(&mut self, ack_set: &A) {
        self.max_streams_sync.on_packet_loss(ack_set)
    }

    /// Gives the `MAX_STREAMS` synchronizer the opportunity to write to
    /// `context`, forwarding `stream_id` and returning the synchronizer's
    /// result unchanged.
    #[inline]
    pub fn on_transmit<W: WriteContext>(
        &mut self,
        stream_id: StreamId,
        context: &mut W,
    ) -> Result<(), OnTransmitError> {
        self.max_streams_sync.on_transmit(stream_id, context)
    }

    /// Stops the `MAX_STREAMS` synchronizer.
    pub fn close(&mut self) {
        self.max_streams_sync.stop_sync();
    }

    /// Asserts the counter invariants in builds with debug assertions
    /// enabled; compiles to nothing otherwise.
    #[inline]
    fn check_integrity(&self) {
        // This check must come first: `open_stream_count` subtracts the
        // closed total from the opened total.
        debug_assert!(
            self.closed_streams <= self.opened_streams,
            "Cannot close more streams than previously opened"
        );
        debug_assert!(
            self.open_stream_count() <= self.max_local_limit,
            "Cannot have more incoming streams open concurrently than the max_local_limit"
        );
    }

    /// Returns the latest limit tracked by the synchronizer (test builds only).
    #[cfg(test)]
    pub fn latest_limit(&self) -> VarInt {
        self.max_streams_sync.latest_value()
    }
}

impl transmission::interest::Provider for RemoteInitiated {
    /// Answers `query` on behalf of this controller by delegating to the
    /// `MAX_STREAMS` synchronizer, which is its only source of transmission
    /// interest.
    #[inline]
    fn transmission_interest<Q: transmission::interest::Query>(
        &self,
        query: &mut Q,
    ) -> transmission::interest::Result {
        let sync = &self.max_streams_sync;
        sync.transmission_interest(query)
    }
}

/// Writes the `MAX_STREAMS` frames based on the stream control window.
///
/// Stateless: it carries no fields and serves as the frame-writer type
/// parameter of the `IncrementalValueSync` held by `RemoteInitiated`.
#[derive(Debug, Default)]
pub(super) struct MaxStreamsToFrameWriter {}

impl ValueToFrameWriter<VarInt> for MaxStreamsToFrameWriter {
    /// Builds a `MAX_STREAMS` frame and writes it to `context`.
    ///
    /// The frame's stream type is taken from `stream_id`, and `value` becomes
    /// its `maximum_streams` field. The result of `context.write_frame` is
    /// returned as-is.
    fn write_value_as_frame<W: WriteContext>(
        &self,
        value: VarInt,
        stream_id: StreamId,
        context: &mut W,
    ) -> Option<PacketNumber> {
        let stream_type = stream_id.stream_type();
        let frame = MaxStreams {
            stream_type,
            maximum_streams: value,
        };
        context.write_frame(&frame)
    }
}