Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tls_codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub use tls_vec::{

#[cfg(feature = "std")]
pub use quic_vec::{SecretVLBytes, rw as vlen};
pub use quic_vec::{VLByteSlice, VLBytes};
pub use quic_vec::{Bytes, VLByteSlice, VLBytes};

#[cfg(feature = "derive")]
pub use tls_codec_derive::{
Expand Down
140 changes: 137 additions & 3 deletions tls_codec/src/quic_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,65 @@ impl From<VLBytes> for Vec<u8> {
}
}

/// Variable-length encoded byte vectors with transparent serde serialization.
///
/// This is equivalent to [`VLBytes`] for TLS codec (de)serialization, but uses
/// `#[serde(transparent)]` so that formats like CBOR serialize the bytes
/// directly instead of wrapping them in a map with a field name.
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct Bytes(
#[cfg_attr(feature = "serde", serde(serialize_with = "serde_bytes::serialize"))]
#[cfg_attr(
feature = "serde",
serde(deserialize_with = "serde_impl::de_vec_bytes_compat")
)]
Vec<u8>,
);

impl Bytes {
/// Generate a new variable-length byte vector.
pub fn new(vec: Vec<u8>) -> Self {
Self(vec)
}

fn vec(&self) -> &[u8] {
&self.0
}

fn vec_mut(&mut self) -> &mut Vec<u8> {
&mut self.0
}
}

impl_vl_bytes_generic!(Bytes);

#[cfg(feature = "std")]
impl Zeroize for Bytes {
fn zeroize(&mut self) {
self.0.zeroize();
}
}

impl From<Bytes> for Vec<u8> {
fn from(b: Bytes) -> Self {
b.0
}
}

impl From<VLBytes> for Bytes {
fn from(b: VLBytes) -> Self {
Self(b.vec)
}
}

impl From<Bytes> for VLBytes {
fn from(b: Bytes) -> Self {
Self { vec: b.0 }
}
}

#[inline(always)]
fn tls_serialize_bytes_len(bytes: &[u8]) -> usize {
let content_length = bytes.len();
Expand Down Expand Up @@ -327,9 +386,31 @@ impl Size for &VLBytes {
}
}

impl Size for Bytes {
#[inline(always)]
fn tls_serialized_len(&self) -> usize {
tls_serialize_bytes_len(self.as_slice())
}
}

impl DeserializeBytes for Bytes {
#[inline(always)]
fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> {
let (vl, remainder) = VLBytes::tls_deserialize_bytes(bytes)?;
Ok((Self::from(vl), remainder))
}
}

impl Size for &Bytes {
#[inline(always)]
fn tls_serialized_len(&self) -> usize {
(*self).tls_serialized_len()
}
}

#[cfg(feature = "serde")]
mod serde_impl {
use std::{fmt, vec::Vec};
use std::{fmt, string::String, vec::Vec};

use serde::{Deserializer, de};

Expand All @@ -343,7 +424,7 @@ mod serde_impl {
type Value = Vec<u8>;

fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("either a byte blob or a sequence of u8")
f.write_str("a byte blob, a sequence of u8, or a map with a \"vec\" key")
}

// New format (native bytes; e.g., CBOR/Bincode/Msgpack)
Expand Down Expand Up @@ -372,6 +453,25 @@ mod serde_impl {
}
Ok(out)
}

// Legacy VLBytes format (map with "vec" key)
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut value: Option<Vec<u8>> = None;
while let Some(key) = map.next_key::<String>()? {
if key == "vec" {
if value.is_some() {
return Err(de::Error::duplicate_field("vec"));
}
value = Some(map.next_value()?);
} else {
return Err(de::Error::unknown_field(&key, &["vec"]));
}
}
value.ok_or_else(|| de::Error::missing_field("vec"))
}
}

deserializer.deserialize_any(BytesOrSeq)
Expand Down Expand Up @@ -573,6 +673,26 @@ mod rw_bytes {
tls_serialize_bytes(writer, self.0)
}
}

impl Serialize for Bytes {
#[inline(always)]
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
tls_serialize_bytes(writer, self.as_slice())
}
}

impl Serialize for &Bytes {
#[inline(always)]
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
(*self).tls_serialize(writer)
}
}

impl Deserialize for Bytes {
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
VLBytes::tls_deserialize(bytes).map(Self::from)
}
}
}

#[cfg(feature = "std")]
Expand Down Expand Up @@ -668,10 +788,19 @@ impl<'a> Arbitrary<'a> for VLBytes {
}
}

#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for Bytes {
fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
let mut vec = Vec::arbitrary(u)?;
vec.truncate(ContentLength::MAX as usize);
Ok(Self(vec))
}
}

#[cfg(feature = "std")]
#[cfg(test)]
mod test {
use crate::{SecretVLBytes, VLByteSlice, VLBytes};
use crate::{Bytes, SecretVLBytes, VLByteSlice, VLBytes};
use std::println;

#[test]
Expand Down Expand Up @@ -700,6 +829,11 @@ mod test {
println!("{got}");
assert_eq!(expected_vl_bytes, got);

let expected_bytes = format!("Bytes {{ {expected} }}");
let got = format!("{:?}", Bytes::new(test.clone()));
println!("{got}");
assert_eq!(expected_bytes, got);

let expected_secret_vl_bytes = format!("SecretVLBytes {{ {expected} }}");
let got = format!("{:?}", SecretVLBytes::new(test.clone()));
println!("{got}");
Expand Down
37 changes: 36 additions & 1 deletion tls_codec/tests/serde_impls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![cfg(feature = "serde")]

use tls_codec::VLBytes;
use tls_codec::{Bytes, VLBytes};

// Old VLBytes without serde bytes serialization
#[derive(serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -33,3 +33,38 @@ fn serde_impls() {

assert_eq!(deserialized, old_deserialized);
}

#[test]
fn bytes_is_transparent() {
let data = vec![32; 128];
let bytes_value = Bytes::new(data.clone());
let vlbytes_value = VLBytes::new(data);

let mut bytes_serialized = Vec::new();
ciborium::into_writer(&bytes_value, &mut bytes_serialized).unwrap();
let mut vlbytes_serialized = Vec::new();
ciborium::into_writer(&vlbytes_value, &mut vlbytes_serialized).unwrap();

// Bytes (transparent) should produce smaller output than VLBytes (has field name)
assert!(bytes_serialized.len() < vlbytes_serialized.len());

// Bytes should roundtrip
let deserialized: Bytes = ciborium::from_reader(bytes_serialized.as_slice()).unwrap();
assert_eq!(deserialized, bytes_value);
}

#[test]
fn bytes_vlbytes_cross_deserialization() {
let data = vec![42; 64];
let bytes_value = Bytes::new(data.clone());
let vlbytes_value = VLBytes::new(data);

let mut bytes_serialized = Vec::new();
ciborium::into_writer(&bytes_value, &mut bytes_serialized).unwrap();
let mut vlbytes_serialized = Vec::new();
ciborium::into_writer(&vlbytes_value, &mut vlbytes_serialized).unwrap();

// Bytes can deserialize VLBytes-serialized data
let from_vlbytes: Bytes = ciborium::from_reader(vlbytes_serialized.as_slice()).unwrap();
assert_eq!(from_vlbytes, bytes_value);
}
Loading