// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{features::gso, message::default as message, socket, syscall};
use s2n_quic_core::{
    endpoint::Endpoint,
    event::{self, EndpointPublisher as _},
    inet::{self, SocketAddress},
    io::event_loop::EventLoop,
    path::MaxMtu,
    time::Clock as ClockTrait,
};
use std::{convert::TryInto, io, io::ErrorKind};
use tokio::runtime::Handle;

mod builder;
mod clock;
pub(crate) mod task;
#[cfg(test)]
mod tests;

/// Path handle type used by this IO provider.
///
/// This is an alias for the `Handle` of the crate's default message implementation
/// (`crate::message::default`); endpoints passed to [`Io::start`] must use it as their
/// `PathHandle`.
pub type PathHandle = message::Handle;
pub use builder::Builder;
pub(crate) use clock::Clock;

/// Tokio-based IO provider.
///
/// An `Io` only holds the [`Builder`] configuration it was created from; no sockets are
/// bound and no tasks are spawned until [`Io::start`] consumes it.
#[derive(Debug, Default)]
pub struct Io {
    /// Configuration that `start` destructures to set up sockets, queues and tasks.
    builder: Builder,
}

impl Io {
    pub fn builder() -> Builder {
        Builder::default()
    }

    pub fn new<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
        let address = addr.to_socket_addrs()?.next().expect("missing address");
        let builder = Builder::default().with_receive_address(address)?;
        Ok(Self { builder })
    }

    pub fn start<E: Endpoint<PathHandle = PathHandle>>(
        self,
        mut endpoint: E,
    ) -> io::Result<(tokio::task::JoinHandle<()>, SocketAddress)> {
        let Builder {
            handle,
            rx_socket,
            tx_socket,
            recv_addr,
            send_addr,
            socket_recv_buffer_size,
            socket_send_buffer_size,
            queue_recv_buffer_size,
            queue_send_buffer_size,
            mut max_mtu,
            max_segments,
            gro_enabled,
            reuse_port,
        } = self.builder;

        let clock = Clock::default();

        let mut publisher = event::EndpointPublisherSubscriber::new(
            event::builder::EndpointMeta {
                endpoint_type: E::ENDPOINT_TYPE,
                timestamp: clock.get_time(),
            },
            None,
            endpoint.subscriber(),
        );

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::Gso {
                max_segments: max_segments.into(),
            },
        });

        // try to use the tokio runtime handle if provided, otherwise try to use the implicit tokio
        // runtime in the current scope of the application.
        let handle = if let Some(handle) = handle {
            handle
        } else {
            Handle::try_current().map_err(|err| std::io::Error::new(io::ErrorKind::Other, err))?
        };

        let guard = handle.enter();

        let rx_socket = if let Some(rx_socket) = rx_socket {
            rx_socket
        } else if let Some(recv_addr) = recv_addr {
            syscall::bind_udp(recv_addr, reuse_port)?
        } else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "missing bind address",
            ));
        };

        let rx_addr = convert_addr_to_std(rx_socket.local_addr()?)?;

        let tx_socket = if let Some(tx_socket) = tx_socket {
            tx_socket
        } else if let Some(send_addr) = send_addr {
            syscall::bind_udp(send_addr, reuse_port)?
        } else {
            // No tx_socket or send address was specified, so the tx socket
            // will be a handle to the rx socket.
            rx_socket.try_clone()?
        };

        if let Some(size) = socket_send_buffer_size {
            tx_socket.set_send_buffer_size(size)?;
        }

        if let Some(size) = socket_recv_buffer_size {
            rx_socket.set_recv_buffer_size(size)?;
        }

        // Configure MTU discovery
        if !syscall::configure_mtu_disc(&tx_socket) {
            // disable MTU probing if we can't prevent fragmentation
            max_mtu = MaxMtu::MIN;
        }

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::MaxMtu {
                mtu: max_mtu.into(),
            },
        });

        // Configure the socket with GRO
        let gro_enabled = gro_enabled.unwrap_or(true) && syscall::configure_gro(&rx_socket);

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::Gro {
                enabled: gro_enabled,
            },
        });

        // Configure packet info CMSG
        syscall::configure_pktinfo(&rx_socket);

        // Configure TOS/ECN
        let tos_enabled = syscall::configure_tos(&rx_socket);

        publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
            configuration: event::builder::PlatformFeatureConfiguration::Ecn {
                enabled: tos_enabled,
            },
        });

        let rx = {
            // if GRO is enabled, then we need to provide the syscall with the maximum size buffer
            let payload_len = if gro_enabled {
                u16::MAX
            } else {
                max_mtu.into()
            } as u32;

            let rx_buffer_size = queue_recv_buffer_size.unwrap_or(8 * (1 << 20));
            let entries = rx_buffer_size / payload_len;
            let entries = if entries.is_power_of_two() {
                entries
            } else {
                // round up to the nearest power of two, since the ring buffers require it
                entries.next_power_of_two()
            };

            let mut consumers = vec![];

            let rx_socket_count = parse_env("S2N_QUIC_UNSTABLE_RX_SOCKET_COUNT").unwrap_or(1);

            for idx in 0usize..rx_socket_count {
                let (producer, consumer) = socket::ring::pair(entries, payload_len);
                consumers.push(consumer);

                // spawn a task that actually reads from the socket into the ring buffer
                if idx + 1 == rx_socket_count {
                    handle.spawn(task::rx(rx_socket, producer));
                    break;
                } else {
                    let rx_socket = rx_socket.try_clone()?;
                    handle.spawn(task::rx(rx_socket, producer));
                }
            }

            // construct the RX side for the endpoint event loop
            let max_mtu = MaxMtu::try_from(payload_len as u16).unwrap();
            let addr: inet::SocketAddress = rx_addr.into();
            socket::io::rx::Rx::new(consumers, max_mtu, addr.into())
        };

        let tx = {
            let gso = crate::features::Gso::from(max_segments);

            // compute the payload size for each message from the number of GSO segments we can
            // fill
            let payload_len = {
                let max_mtu: u16 = max_mtu.into();
                (max_mtu as u32 * gso.max_segments() as u32).min(u16::MAX as u32)
            };

            let tx_buffer_size = queue_send_buffer_size.unwrap_or(128 * 1024);
            let entries = tx_buffer_size / payload_len;
            let entries = if entries.is_power_of_two() {
                entries
            } else {
                // round up to the nearest power of two, since the ring buffers require it
                entries.next_power_of_two()
            };

            let mut producers = vec![];

            let tx_socket_count = parse_env("S2N_QUIC_UNSTABLE_TX_SOCKET_COUNT").unwrap_or(1);

            for idx in 0usize..tx_socket_count {
                let (producer, consumer) = socket::ring::pair(entries, payload_len);
                producers.push(producer);

                // spawn a task that actually flushes the ring buffer to the socket
                if idx + 1 == tx_socket_count {
                    handle.spawn(task::tx(tx_socket, consumer, gso.clone()));
                    break;
                } else {
                    let tx_socket = tx_socket.try_clone()?;
                    handle.spawn(task::tx(tx_socket, consumer, gso.clone()));
                }
            }

            // construct the TX side for the endpoint event loop
            socket::io::tx::Tx::new(producers, gso, max_mtu)
        };

        // Notify the endpoint of the MTU that we chose
        endpoint.set_max_mtu(max_mtu);

        let task = handle.spawn(
            EventLoop {
                endpoint,
                clock,
                rx,
                tx,
            }
            .start(),
        );

        drop(guard);

        Ok((task, rx_addr.into()))
    }
}

/// Converts a `socket2` address into a standard library socket address.
///
/// Fails with [`io::ErrorKind::InvalidInput`] when `as_socket` yields no
/// `std::net::SocketAddr` for the given address.
fn convert_addr_to_std(addr: socket2::SockAddr) -> io::Result<std::net::SocketAddr> {
    match addr.as_socket() {
        Some(std_addr) => Ok(std_addr),
        None => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "invalid domain for socket",
        )),
    }
}

/// Reads the environment variable `name` and parses it as a `T`.
///
/// Returns `None` when the variable cannot be read (unset or not valid unicode) or when its
/// value does not parse as `T`.
fn parse_env<T: core::str::FromStr>(name: &str) -> Option<T> {
    let raw = std::env::var(name).ok()?;
    raw.parse().ok()
}