// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use crate::{ packet::number::{ packet_number_space::PacketNumberSpace, truncated_packet_number::TruncatedPacketNumber, PACKET_NUMBER_LEN_MASK, }, varint::VarInt, }; use s2n_codec::{u24, DecoderBuffer, DecoderBufferResult}; /// A fully-decoded and unprotected packet number length #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct PacketNumberLen { pub(crate) space: PacketNumberSpace, pub(crate) value: PacketNumberLenValue, } impl PacketNumberLen { pub const MAX_LEN: usize = U32_SIZE; /// Returns the max `PacketNumberLen` value for the given `PacketNumberSpace` pub const fn max(space: PacketNumberSpace) -> Self { Self { value: PacketNumberLenValue::U32, space, } } /// Returns the `PacketNumberSpace` for the given `PacketNumberLen` #[inline] pub const fn space(self) -> PacketNumberSpace { self.space } /// Decodes a `TruncatedPacketNumber` with the given `PacketNumberLen` #[inline] pub fn decode_truncated_packet_number( self, buffer: DecoderBuffer, ) -> DecoderBufferResult { self.value .decode_truncated_packet_number(buffer, self.space) } /// Returns a packet tag mask for the given `PacketNumberLen`. #[inline] pub fn into_packet_tag_mask(self) -> u8 { self.value.into_packet_tag_mask() } /// Returns the bytesize required for encoding the given `PacketNumberLen` #[inline] pub fn bytesize(self) -> usize { self.value.bytesize() } /// Returns the bitsize required for encoding the given `PacketNumberLen` #[inline] pub fn bitsize(self) -> usize { self.value.bitsize() } #[inline] pub(crate) fn truncate_packet_number(self, value: VarInt) -> TruncatedPacketNumber { self.value.truncate_packet_number(value, self.space) } #[inline] pub(crate) fn from_packet_tag(tag: u8, space: PacketNumberSpace) -> Self { Self { value: PacketNumberLenValue::from_packet_tag(tag), space, } } #[inline] pub(crate) fn from_varint(value: VarInt, space: PacketNumberSpace) -> Option { Some(Self { value: PacketNumberLenValue::from_varint(value)?, space, }) } } const U8_TAG: u8 = 0; // (8 / 8) - 1; const U16_TAG: u8 = (16 / 8) - 1; const U24_TAG: u8 = (24 / 8) - 1; const U32_TAG: u8 = (32 / 8) - 1; const U32_SIZE: usize = 32 / 8; const U8_MAX: u64 = (1 << 8) - 1; const U16_MAX: u64 = (1 << 16) - 1; const U24_MAX: u64 = (1 << 24) - 1; const U32_MAX: u64 = (1 << 32) - 1; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub(crate) enum PacketNumberLenValue { U8, U16, U24, U32, } impl PacketNumberLenValue { #[inline] pub fn decode_truncated_packet_number( self, buffer: DecoderBuffer, space: PacketNumberSpace, ) -> DecoderBufferResult { match self { Self::U8 => TruncatedPacketNumber::decode::(buffer, space), Self::U16 => TruncatedPacketNumber::decode::(buffer, space), Self::U24 => TruncatedPacketNumber::decode::(buffer, space), Self::U32 => TruncatedPacketNumber::decode::(buffer, space), } } #[inline] pub(crate) fn truncate_packet_number( self, value: VarInt, space: PacketNumberSpace, ) -> TruncatedPacketNumber { match self { Self::U8 => TruncatedPacketNumber::new(*value as u8, space), Self::U16 => TruncatedPacketNumber::new(*value as u16, space), Self::U24 => TruncatedPacketNumber::new(u24::new_truncated(*value as u32), space), Self::U32 => TruncatedPacketNumber::new(*value as u32, space), } } #[inline] pub fn into_packet_tag_mask(self) -> u8 { self as u8 } #[inline] pub fn bytesize(self) -> usize { self as usize + 1 } #[inline] pub fn bitsize(self) -> usize { self.bytesize() * 8 } #[inline] pub fn from_packet_tag(tag: u8) -> Self { match tag & PACKET_NUMBER_LEN_MASK { U8_TAG => Self::U8, U16_TAG => Self::U16, U24_TAG => Self::U24, U32_TAG => Self::U32, _ => unreachable!("the mask only allows for 4 valid values"), } } #[inline] pub fn from_varint(value: VarInt) -> Option { #[allow(clippy::match_overlapping_arm)] match *value { 0..=U8_MAX => Some(Self::U8), 0..=U16_MAX => Some(Self::U16), 0..=U24_MAX => Some(Self::U24), 0..=U32_MAX => Some(Self::U32), _ => None, } } } #[cfg(test)] mod tests { use super::*; /// the code relies on the variants to be in ascending order #[test] fn ordering_test() { assert!(PacketNumberLenValue::U8 < PacketNumberLenValue::U16); assert!(PacketNumberLenValue::U16 < PacketNumberLenValue::U24); assert!(PacketNumberLenValue::U24 < PacketNumberLenValue::U32); } }