// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use bach::time::scheduler; use core::task::Poll; use s2n_quic_core::{ endpoint::Endpoint, inet::SocketAddress, io::event_loop::EventLoop, path::MaxMtu, }; type Error = std::io::Error; type Result = core::result::Result; pub mod message; mod model; pub mod network; mod socket; pub mod time; pub use model::{Model, TxRecorder}; pub use network::{Network, PathHandle}; pub use socket::Socket; pub use time::now; pub use bach::task::{self, primary, spawn}; pub mod rand { pub use ::bach::rand::*; #[derive(Clone, Copy, Default)] pub struct Havoc; impl s2n_quic_core::havoc::Random for Havoc { #[inline] fn fill(&mut self, bytes: &mut [u8]) { fill_bytes(bytes); } #[inline] fn gen_bool(&mut self) -> bool { gen() } #[inline] fn shuffle(&mut self, bytes: &mut [u8]) { shuffle(bytes); } #[inline] fn gen_range(&mut self, range: core::ops::Range) -> u64 { gen_range(range) } } } pub mod executor { pub use bach::executor::{Handle, JoinHandle}; } pub struct Executor { executor: bach::executor::Executor>, handle: Handle, } impl Executor { pub fn new(network: N, seed: u64) -> Self { let mut executor = bach::executor::Executor::new(|handle| Env { handle: handle.clone(), time: scheduler::Scheduler::new(), rand: bach::rand::Scope::new(seed), buffers: network::Buffers::default(), network, stalled_iterations: 0, }); let handle = Handle { executor: executor.handle().clone(), buffers: executor.environment().buffers.clone(), }; Self { executor, handle } } pub fn handle(&self) -> &Handle { &self.handle } pub fn enter O, O>(&mut self, f: F) -> O { self.executor.environment().enter(f) } pub fn run(&mut self) { self.executor.block_on_primary(); } pub fn close(&mut self) { // close the environment, which notifies all of the tasks that we're shutting down self.executor.environment().close(|| {}); while self.executor.macrostep() > 0 {} // then close the actual executor self.executor.close() } } impl Drop for Executor { fn drop(&mut self) { self.close(); } } struct Env { handle: bach::executor::Handle, time: scheduler::Scheduler, rand: bach::rand::Scope, buffers: network::Buffers, network: N, stalled_iterations: usize, } impl Env { fn enter O, O>(&self, f: F) -> O { self.handle.enter(|| self.time.enter(|| self.rand.enter(f))) } fn close(&mut self, f: F) { let handle = &mut self.handle; let rand = &mut self.rand; let time = &mut self.time; let buffers = &mut self.buffers; handle.enter(|| { rand.enter(|| { time.close(); time.enter(|| { buffers.close(); f(); }); }) }) } } impl bach::executor::Environment for Env { fn run(&mut self, tasks: Tasks) -> Poll<()> where Tasks: Iterator + Send, F: 'static + FnOnce() -> Poll<()> + Send, { let mut is_ready = true; let Self { handle, time, rand, buffers, network, .. } = self; handle.enter(|| { time.enter(|| { rand.enter(|| { for task in tasks { is_ready &= task().is_ready(); } network.execute(buffers); }) }) }); if is_ready { Poll::Ready(()) } else { Poll::Pending } } fn on_macrostep(&mut self, count: usize) { // only advance time after a stall if count > 0 { self.stalled_iterations = 0; return; } self.stalled_iterations += 1; // A stalled iteration is a macrostep that didn't actually execute any tasks. // // The idea with limiting it prevents the runtime from looping endlessly and not // actually doing any work. The value of 100 was chosen somewhat arbitrarily as a high // enough number that we won't get false positives but low enough that the number of // loops stays within reasonable ranges. if self.stalled_iterations > 100 { panic!("the runtime stalled after 100 iterations"); } while let Some(time) = self.time.advance() { let _ = time; if self.time.wake() > 0 { // if a task has woken, then reset the stall count self.stalled_iterations = 0; break; } } } fn close(&mut self, close: F) where F: 'static + FnOnce() + Send, { Self::close(self, close) } } #[derive(Clone)] pub struct Handle { executor: executor::Handle, buffers: network::Buffers, } impl Handle { pub fn builder(&self) -> Builder { Builder { handle: self.clone(), address: None, on_socket: None, max_mtu: MaxMtu::default(), queue_recv_buffer_size: None, queue_send_buffer_size: None, } } } pub struct Builder { handle: Handle, address: Option, on_socket: Option>, max_mtu: MaxMtu, queue_recv_buffer_size: Option, queue_send_buffer_size: Option, } impl Builder { pub fn build(self) -> Result { Ok(Io { builder: self }) } pub fn with_max_mtu(mut self, max_mtu: u16) -> Self { self.max_mtu = max_mtu.try_into().unwrap(); self } pub fn on_socket(mut self, f: impl FnOnce(socket::Socket) + 'static) -> Self { self.on_socket = Some(Box::new(f)); self } /// Sets the size of the send buffer associated with the transmit side (internal to s2n-quic) pub fn with_internal_send_buffer_size( mut self, send_buffer_size: usize, ) -> std::io::Result { self.queue_send_buffer_size = Some(send_buffer_size.try_into().map_err(|err| { std::io::Error::new(std::io::ErrorKind::InvalidInput, format!("{err}")) })?); Ok(self) } /// Sets the size of the send buffer associated with the receive side (internal to s2n-quic) pub fn with_internal_recv_buffer_size( mut self, recv_buffer_size: usize, ) -> std::io::Result { self.queue_recv_buffer_size = Some(recv_buffer_size.try_into().map_err(|err| { std::io::Error::new(std::io::ErrorKind::InvalidInput, format!("{err}")) })?); Ok(self) } } pub struct Io { builder: Builder, } impl Io { pub fn start>( self, mut endpoint: E, ) -> Result<(executor::JoinHandle<()>, SocketAddress)> { let Builder { handle: Handle { executor, buffers }, address, on_socket, max_mtu, queue_recv_buffer_size, queue_send_buffer_size, } = self.builder; endpoint.set_max_mtu(max_mtu); let handle = address.unwrap_or_else(|| buffers.generate_addr()); let (tx, rx, socket) = buffers.register( handle, self.builder.max_mtu, queue_recv_buffer_size, queue_send_buffer_size, ); if let Some(on_socket) = on_socket { on_socket(socket); } let clock = time::Clock::default(); let event_loop = EventLoop { endpoint, clock, tx, rx, }; let join = executor.spawn(event_loop.start()); Ok((join, handle)) } }