// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use core::{fmt, marker::PhantomData, str::FromStr, time::Duration}; use s2n_quic::provider::io::testing::rand; use serde::Deserialize; #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct CliRange { pub start: T, pub end: T, } impl Default for CliRange { fn default() -> Self { Self { start: 0.0, end: 0.0, } } } impl Default for CliRange { fn default() -> Self { Self { start: 0, end: 0 } } } impl Default for CliRange { fn default() -> Self { Self { start: Duration::ZERO.into(), end: Duration::ZERO.into(), } } } impl fmt::Display for CliRange { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.start == self.end { self.start.fmt(f) } else { write!(f, "{}..{}", self.start, self.end) } } } impl CliRange where T: Copy + PartialOrd + ::rand::distributions::uniform::SampleUniform, { pub fn gen(&self) -> T { if self.start == self.end { return self.start; } rand::gen_range(self.start..self.end) } } impl CliRange { pub fn gen_duration(&self) -> Duration { let start = self.start.as_nanos(); let end = self.end.as_nanos(); if start == end { return Duration::from_nanos(start as _); } let nanos = rand::gen_range(start..end); Duration::from_nanos(nanos as _) } } impl FromStr for CliRange { type Err = T::Err; fn from_str(s: &str) -> Result { if let Some((start, end)) = s.split_once("..") { let start = start.parse()?; let end = end.parse()?; Ok(Self { start, end }) } else { let start = s.parse()?; let end = start; Ok(Self { start, end }) } } } impl<'de, T> Deserialize<'de> for CliRange where T: Copy + FromStr, ::Err: core::fmt::Display, { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { deserializer.deserialize_any(Visitor::(PhantomData)) } } struct Visitor(PhantomData); impl<'de, T> serde::de::Visitor<'de> for Visitor where T: Copy + FromStr, ::Err: core::fmt::Display, { type Value = CliRange; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { write!(formatter, "a range or individual value") } fn visit_i64(self, v: i64) -> Result where E: serde::de::Error, { v.to_string().parse().map_err(E::custom) } fn visit_u64(self, v: u64) -> Result where E: serde::de::Error, { v.to_string().parse().map_err(E::custom) } fn visit_f64(self, v: f64) -> Result where E: serde::de::Error, { v.to_string().parse().map_err(E::custom) } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { v.parse().map_err(E::custom) } }