// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use bytes::{Buf, Bytes};
use futures::ready;
use h3::quic::{self, Error, StreamId, WriteBuf};
use s2n_quic::stream::{BidirectionalStream, ReceiveStream};
use s2n_quic_core::varint::VarInt;
use std::{
    convert::TryInto,
    fmt::{self, Display},
    sync::Arc,
    task::{self, Poll},
};

/// An s2n-quic connection adapted to h3's `quic::Connection` trait.
///
/// The connection is split once, up front, into a handle used to open outbound
/// streams and two acceptors for peer-initiated bidirectional and receive-only streams.
pub struct Connection {
    conn: s2n_quic::connection::Handle,
    bidi_acceptor: s2n_quic::connection::BidirectionalStreamAcceptor,
    recv_acceptor: s2n_quic::connection::ReceiveStreamAcceptor,
}

impl Connection {
    /// Wraps `new_conn`, splitting it into its handle and stream acceptors.
    pub fn new(new_conn: s2n_quic::Connection) -> Self {
        let (conn, stream_acceptor) = new_conn.split();
        let (bidi_acceptor, recv_acceptor) = stream_acceptor.split();

        Self {
            conn,
            bidi_acceptor,
            recv_acceptor,
        }
    }
}

#[derive(Debug)]
pub struct ConnectionError(s2n_quic::connection::Error);

impl std::error::Error for ConnectionError {}

impl fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

impl Error for ConnectionError {
    fn is_timeout(&self) -> bool {
        matches!(self.0, s2n_quic::connection::Error::IdleTimerExpired { .. })
    }

    fn err_code(&self) -> Option<u64> {
        match self.0 {
            s2n_quic::connection::Error::Application { error, .. } => Some(error.into()),
            _ => None,
        }
    }
}

impl From<s2n_quic::connection::Error> for ConnectionError {
    fn from(e: s2n_quic::connection::Error) -> Self {
        Self(e)
    }
}

impl<B> quic::Connection<B> for Connection
where
    B: Buf,
{
    type BidiStream = BidiStream<B>;
    type SendStream = SendStream<B>;
    type RecvStream = RecvStream;
    type OpenStreams = OpenStreams;
    type Error = ConnectionError;

    fn poll_accept_recv(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> {
        let recv = match ready!(self.recv_acceptor.poll_accept_receive_stream(cx))? {
            Some(x) => x,
            None => return Poll::Ready(Ok(None)),
        };
        Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
    }

    fn poll_accept_bidi(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> {
        let (recv, send) = match ready!(self.bidi_acceptor.poll_accept_bidirectional_stream(cx))? {
            Some(x) => x.split(),
            None => return Poll::Ready(Ok(None)),
        };
        Poll::Ready(Ok(Some(Self::BidiStream {
            send: Self::SendStream::new(send),
            recv: Self::RecvStream::new(recv),
        })))
    }

    fn poll_open_bidi(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Self::BidiStream, Self::Error>> {
        let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?;
        Ok(stream.into()).into()
    }

    fn poll_open_send(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Self::SendStream, Self::Error>> {
        let stream = ready!(self.conn.poll_open_send_stream(cx))?;
        Ok(stream.into()).into()
    }

    fn opener(&self) -> Self::OpenStreams {
        OpenStreams {
            conn: self.conn.clone(),
        }
    }

    fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
        self.conn.close(
            code.value()
                .try_into()
                .expect("s2n-quic supports error codes up to 2^62-1"),
        );
    }
}

pub struct OpenStreams {
    conn: s2n_quic::connection::Handle,
}

impl<B> quic::OpenStreams<B> for OpenStreams
where
    B: Buf,
{
    type BidiStream = BidiStream<B>;
    type SendStream = SendStream<B>;
    type RecvStream = RecvStream;
    type Error = ConnectionError;

    fn poll_open_bidi(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Self::BidiStream, Self::Error>> {
        let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?;
        Ok(stream.into()).into()
    }

    fn poll_open_send(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Self::SendStream, Self::Error>> {
        let stream = ready!(self.conn.poll_open_send_stream(cx))?;
        Ok(stream.into()).into()
    }

    fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
        self.conn.close(
            code.value()
                .try_into()
                .unwrap_or_else(|_| VarInt::MAX.into()),
        );
    }
}

impl Clone for OpenStreams {
    fn clone(&self) -> Self {
        Self {
            conn: self.conn.clone(),
        }
    }
}

/// A bidirectional stream held as its two independently usable halves.
pub struct BidiStream<B>
where
    B: Buf,
{
    send: SendStream<B>,
    recv: RecvStream,
}

impl<B> quic::BidiStream<B> for BidiStream<B>
where
    B: Buf,
{
    type SendStream = SendStream<B>;
    type RecvStream = RecvStream;

    /// Consumes the stream and hands back its `(send, recv)` halves.
    fn split(self) -> (Self::SendStream, Self::RecvStream) {
        let Self { send, recv } = self;
        (send, recv)
    }
}

/// Receive-side operations on a bidirectional stream, forwarded to its `recv` half.
impl<B> quic::RecvStream for BidiStream<B>
where
    B: Buf,
{
    type Buf = Bytes;
    type Error = ReadError;

    fn poll_data(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
        quic::RecvStream::poll_data(&mut self.recv, cx)
    }

    fn stop_sending(&mut self, error_code: u64) {
        quic::RecvStream::stop_sending(&mut self.recv, error_code)
    }
}

/// Send-side operations on a bidirectional stream, forwarded to its `send` half.
impl<B> quic::SendStream<B> for BidiStream<B>
where
    B: Buf,
{
    type Error = SendStreamError;

    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        quic::SendStream::<B>::poll_ready(&mut self.send, cx)
    }

    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        quic::SendStream::<B>::poll_finish(&mut self.send, cx)
    }

    fn reset(&mut self, reset_code: u64) {
        quic::SendStream::<B>::reset(&mut self.send, reset_code)
    }

    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
        quic::SendStream::<B>::send_data(&mut self.send, data)
    }

    /// The stream id, as reported by the send half.
    fn id(&self) -> StreamId {
        quic::SendStream::<B>::id(&self.send)
    }
}

/// Wraps an s2n-quic bidirectional stream by splitting it and wrapping each half.
impl<B> From<BidirectionalStream> for BidiStream<B>
where
    B: Buf,
{
    fn from(bidi: BidirectionalStream) -> Self {
        let (receive_half, send_half) = bidi.split();
        Self {
            recv: From::from(receive_half),
            send: From::from(send_half),
        }
    }
}

/// Receive half of a stream, wrapping an s2n-quic `ReceiveStream`.
pub struct RecvStream {
    stream: s2n_quic::stream::ReceiveStream,
}

impl RecvStream {
    /// Wraps `inner` without any additional state.
    fn new(inner: s2n_quic::stream::ReceiveStream) -> Self {
        RecvStream { stream: inner }
    }
}

impl quic::RecvStream for RecvStream {
    type Buf = Bytes;
    type Error = ReadError;

    fn poll_data(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
        let buf = ready!(self.stream.poll_receive(cx))?;
        Ok(buf).into()
    }

    fn stop_sending(&mut self, error_code: u64) {
        let _ = self.stream.stop_sending(
            s2n_quic::application::Error::new(error_code)
                .expect("s2n-quic supports error codes up to 2^62-1"),
        );
    }
}

impl From<ReceiveStream> for RecvStream {
    fn from(recv: ReceiveStream) -> Self {
        RecvStream::new(recv)
    }
}

#[derive(Debug)]
pub struct ReadError(s2n_quic::stream::Error);

impl std::error::Error for ReadError {}

impl fmt::Display for ReadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

impl From<ReadError> for Arc<dyn Error> {
    fn from(e: ReadError) -> Self {
        Arc::new(e)
    }
}

impl From<s2n_quic::stream::Error> for ReadError {
    fn from(e: s2n_quic::stream::Error) -> Self {
        Self(e)
    }
}

impl Error for ReadError {
    fn is_timeout(&self) -> bool {
        matches!(
            self.0,
            s2n_quic::stream::Error::ConnectionError {
                error: s2n_quic::connection::Error::IdleTimerExpired { .. },
                ..
            }
        )
    }

    fn err_code(&self) -> Option<u64> {
        match self.0 {
            s2n_quic::stream::Error::ConnectionError {
                error: s2n_quic::connection::Error::Application { error, .. },
                ..
            } => Some(error.into()),
            s2n_quic::stream::Error::StreamReset { error, .. } => Some(error.into()),
            _ => None,
        }
    }
}

/// Send half of a stream, wrapping an s2n-quic `SendStream`.
///
/// `buf` holds data handed over by `send_data` until `poll_ready` drains it.
/// `chunk` holds the piece of that data currently being flushed.
pub struct SendStream<B: Buf> {
    stream: s2n_quic::stream::SendStream,
    chunk: Option<Bytes>,
    buf: Option<WriteBuf<B>>, // TODO: Replace with buf: PhantomData<B>
                              //       after https://github.com/hyperium/h3/issues/78 is resolved
}

impl<B> SendStream<B>
where
    B: Buf,
{
    /// Wraps `stream` with nothing queued or in flight.
    fn new(stream: s2n_quic::stream::SendStream) -> SendStream<B> {
        SendStream {
            chunk: None,
            buf: None,
            stream,
        }
    }
}

// h3 send-side operations backed by an s2n-quic `SendStream`.
//
// Data given to `send_data` is held in `self.buf`. It is only written to the
// QUIC stream by `poll_ready` (and by `poll_finish`, which calls it).
impl<B> quic::SendStream<B> for SendStream<B>
where
    B: Buf,
{
    type Error = SendStreamError;

    /// Drains any data queued by `send_data` into the s2n-quic stream.
    ///
    /// An in-flight `chunk` is flushed first. After that, the queued `WriteBuf` is
    /// taken one `chunk()` at a time: each piece is turned into `Bytes` and flushed
    /// on the next loop iteration. Returns `Poll::Pending` whenever s2n-quic's
    /// `poll_send` is pending. Returns `Poll::Ready(Ok(()))` once neither a chunk
    /// nor a non-empty `WriteBuf` remains, at which point `self.buf` is cleared.
    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        loop {
            // try to flush the current chunk if we have one
            if let Some(chunk) = self.chunk.as_mut() {
                ready!(self.stream.poll_send(chunk, cx))?;

                // s2n-quic will take the whole chunk on send, even if it exceeds the limits
                debug_assert!(chunk.is_empty());
                self.chunk = None;
            }

            // try to take the next chunk from the WriteBuf
            if let Some(ref mut data) = self.buf {
                let len = data.chunk().len();

                // if the write buf is empty, then clear it and break
                if len == 0 {
                    self.buf = None;
                    break;
                }

                // copy the first chunk from WriteBuf and prepare it to flush
                let chunk = data.copy_to_bytes(len);
                self.chunk = Some(chunk);

                // loop back around to flush the chunk
                continue;
            }

            // if we didn't have either a chunk or WriteBuf, then we're ready
            break;
        }

        Poll::Ready(Ok(()))

        // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved
        // self.available_bytes = ready!(self.stream.poll_send_ready(cx))?;
        // Poll::Ready(Ok(()))
    }

    /// Queues `data` to be written by a later `poll_ready` / `poll_finish`.
    ///
    /// Returns `SendStreamError::NotReady` if a previously queued buffer has not
    /// yet been fully drained (`self.buf` is still `Some`). Nothing is written to
    /// the QUIC stream by this call itself.
    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
        if self.buf.is_some() {
            return Err(Self::Error::NotReady);
        }
        self.buf = Some(data.into());
        Ok(())

        // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved
        // let mut data = data.into();
        // while self.available_bytes > 0 && data.has_remaining() {
        //     let len = data.chunk().len();
        //     let chunk = data.copy_to_bytes(len);
        //     self.stream.send_data(chunk)?;
        //     self.available_bytes = self.available_bytes.saturating_sub(len);
        // }
        // Ok(())
    }

    /// Flushes all queued data, then calls s2n-quic's `finish` on the stream.
    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        // ensure all chunks are flushed to the QUIC stream before finishing
        ready!(self.poll_ready(cx))?;
        self.stream.finish()?;
        Ok(()).into()
    }

    /// Resets the stream with `reset_code`.
    ///
    /// Codes that cannot be converted into an s2n-quic application error are
    /// clamped to `VarInt::MAX`. Any error from the reset is ignored.
    fn reset(&mut self, reset_code: u64) {
        let _ = self
            .stream
            .reset(reset_code.try_into().unwrap_or_else(|_| VarInt::MAX.into()));
    }

    /// Returns the s2n-quic stream id converted to h3's `StreamId`.
    ///
    /// Panics with "invalid stream id" if the conversion fails.
    fn id(&self) -> StreamId {
        self.stream.id().try_into().expect("invalid stream id")
    }
}

/// Wraps a raw s2n-quic send stream with empty send buffers.
impl<B> From<s2n_quic::stream::SendStream> for SendStream<B>
where
    B: Buf,
{
    fn from(send: s2n_quic::stream::SendStream) -> Self {
        Self::new(send)
    }
}

/// Errors produced by the send side of a stream.
#[derive(Debug)]
pub enum SendStreamError {
    /// The underlying s2n-quic stream reported an error.
    Write(s2n_quic::stream::Error),
    /// `send_data` was called while previously queued data was still pending.
    NotReady,
}

impl std::error::Error for SendStreamError {}

impl Display for SendStreamError {
    /// User-facing message.
    ///
    /// A `Write` error shows the wrapped s2n-quic error's own `Display`, matching
    /// `ReadError` and `ConnectionError`. `NotReady` gets a fixed description.
    /// The previous implementation printed the `Debug` representation instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Write(err) => Display::fmt(err, f),
            Self::NotReady => f.write_str("send stream is not ready: previously queued data is still pending"),
        }
    }
}

impl From<s2n_quic::stream::Error> for SendStreamError {
    fn from(e: s2n_quic::stream::Error) -> Self {
        Self::Write(e)
    }
}

impl Error for SendStreamError {
    fn is_timeout(&self) -> bool {
        matches!(
            self,
            Self::Write(s2n_quic::stream::Error::ConnectionError {
                error: s2n_quic::connection::Error::IdleTimerExpired { .. },
                ..
            })
        )
    }

    fn err_code(&self) -> Option<u64> {
        match self {
            Self::Write(s2n_quic::stream::Error::StreamReset { error, .. }) => {
                Some((*error).into())
            }
            Self::Write(s2n_quic::stream::Error::ConnectionError {
                error: s2n_quic::connection::Error::Application { error, .. },
                ..
            }) => Some((*error).into()),
            _ => None,
        }
    }
}

impl From<SendStreamError> for Arc<dyn Error> {
    fn from(e: SendStreamError) -> Self {
        Arc::new(e)
    }
}