From ad5ab0bef77617cc57c9b3e2e5fec2818d4bd2c2 Mon Sep 17 00:00:00 2001 From: Arjun Ramesh Date: Tue, 7 Apr 2026 10:17:42 -0400 Subject: [PATCH] Added flat type computation for interface types for RR --- crates/environ/src/component/types.rs | 336 ++++++++++++++++++ crates/environ/src/component/types_builder.rs | 74 +--- 2 files changed, 337 insertions(+), 73 deletions(-) diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs index 5cbfde24c693..524039d33e8a 100644 --- a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -381,6 +381,225 @@ impl ComponentTypes { } } + /// Returns the flat representation of a function's params and returns for + /// the underlying core wasm function according to the Canonical ABI + /// + /// As per the Canonical ABI, when the representation is larger than MAX_FLAT_RESULTS + /// or MAX_FLAT_PARAMS, the core wasm function will take a pointer to the arg/result list. + /// Returns (param_storage, result_storage) + pub fn flat_func_type( + &self, + ty: &TypeFunc, + context: FlatFuncTypeContext, + ) -> ( + FlatTypesStorage, + FlatTypesStorage, + ) { + let mut params_storage = self + .flat_interface_type(&InterfaceType::Tuple(ty.params), MAX_FLAT_PARAMS) + .unwrap_or_else(|| { + let mut flat = FlatTypesStorage::new(); + flat.push(FlatType::I32, FlatType::I64); + flat + }); + let results_storage = self + .flat_interface_type(&InterfaceType::Tuple(ty.results), MAX_FLAT_RESULTS) + .unwrap_or_else(|| { + let mut flat = FlatTypesStorage::new(); + match context { + FlatFuncTypeContext::Lift => { + flat.push(FlatType::I32, FlatType::I64); + } + // For lowers, the retptr is passed as the last parameter + FlatFuncTypeContext::Lower => { + params_storage.push(FlatType::I32, FlatType::I64); + } + } + flat + }); + (params_storage, results_storage) + } + + fn flat_interface_type( + &self, + ty: &InterfaceType, + limit: usize, + ) -> Option> { + // Helper routines + let push = |storage: 
&mut FlatTypesStorage, t32: FlatType, t64: FlatType| -> bool { + storage.push(t32, t64); + (storage.len as usize) <= limit + }; + + let push_discrim = |storage: &mut FlatTypesStorage| -> bool { + push(storage, FlatType::I32, FlatType::I32) + }; + + let push_storage = + |storage: &mut FlatTypesStorage, other: Option>| -> bool { + other + .and_then(|other| { + let len = usize::from(storage.len); + let other_len = usize::from(other.len); + (len + other_len <= limit).then(|| { + storage.memory32[len..len + other_len] + .copy_from_slice(&other.memory32[..other_len]); + storage.memory64[len..len + other_len] + .copy_from_slice(&other.memory64[..other_len]); + storage.len += other.len; + }) + }) + .is_some() + }; + + let push_storage_n = |storage: &mut FlatTypesStorage, + other: Option>, + n: usize| + -> bool { + other + .and_then(|other| { + let len = usize::from(storage.len); + let other_len = usize::from(other.len); + let other_total_len = other_len * n; + (len + other_total_len <= limit).then(|| { + for _ in 0..n { + storage + .memory32 + .copy_from_slice(&other.memory32[..other_len]); + storage + .memory64 + .copy_from_slice(&other.memory64[..other_len]); + } + storage.len += other_total_len as u8; + }) + }) + .is_some() + }; + + // Case is broken down as: + // * None => No field + // * Some(None) => Invalid storage (overflow) + // * Some(storage) => Valid storage + let push_storage_variant_case = |storage: &mut FlatTypesStorage, + case: Option>>| + -> bool { + match case { + None => true, + Some(case) => { + case.and_then(|case| { + // Discriminant will make size[case] = limit overshoot + ((1 + case.len as usize) <= limit).then(|| { + // Skip 1 for discriminant + let dst = storage + .memory32 + .iter_mut() + .zip(&mut storage.memory64) + .skip(1); + for (i, ((t32, t64), (dst32, dst64))) in case + .memory32 + .iter() + .take(case.len as usize) + .zip(case.memory64.iter()) + .zip(dst) + .enumerate() + { + if i + 1 < usize::from(storage.len) { + // Populated Index + 
dst32.join(*t32); + dst64.join(*t64); + } else { + // New Index + storage.len += 1; + *dst32 = *t32; + *dst64 = *t64; + } + } + }) + }) + .is_some() + } + } + }; + + // Logic + let mut storage_buf = FlatTypesStorage::new(); + let storage = &mut storage_buf; + + match ty { + InterfaceType::U8 + | InterfaceType::S8 + | InterfaceType::Bool + | InterfaceType::U16 + | InterfaceType::S16 + | InterfaceType::U32 + | InterfaceType::S32 + | InterfaceType::Char + | InterfaceType::Own(_) + | InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_) + | InterfaceType::Borrow(_) + | InterfaceType::Enum(_) => push(storage, FlatType::I32, FlatType::I32), + + InterfaceType::U64 | InterfaceType::S64 => push(storage, FlatType::I64, FlatType::I64), + InterfaceType::Float32 => push(storage, FlatType::F32, FlatType::F32), + InterfaceType::Float64 => push(storage, FlatType::F64, FlatType::F64), + InterfaceType::String | InterfaceType::List(_) | InterfaceType::Map(_) => { + // Pointer pair + push(storage, FlatType::I32, FlatType::I64) + && push(storage, FlatType::I32, FlatType::I64) + } + + InterfaceType::Record(i) => self[*i] + .fields + .iter() + .all(|field| push_storage(storage, self.flat_interface_type(&field.ty, limit))), + InterfaceType::Tuple(i) => self[*i] + .types + .iter() + .all(|field| push_storage(storage, self.flat_interface_type(field, limit))), + InterfaceType::Flags(i) => match FlagsSize::from_count(self[*i].names.len()) { + FlagsSize::Size0 => true, + FlagsSize::Size1 | FlagsSize::Size2 => push(storage, FlatType::I32, FlatType::I32), + FlagsSize::Size4Plus(n) => (0..n) + .into_iter() + .all(|_| push(storage, FlatType::I32, FlatType::I32)), + }, + InterfaceType::Variant(i) => { + push_discrim(storage) + && self[*i].cases.values().all(|case| { + let case_flat = case.as_ref().map(|ty| self.flat_interface_type(ty, limit)); + push_storage_variant_case(storage, case_flat) + }) + } + InterfaceType::Option(i) => { + push_discrim(storage) + && 
push_storage_variant_case(storage, None) + && push_storage_variant_case( + storage, + Some(self.flat_interface_type(&self[*i].ty, limit)), + ) + } + InterfaceType::Result(i) => { + push_discrim(storage) + && push_storage_variant_case( + storage, + self[*i].ok.map(|ty| self.flat_interface_type(&ty, limit)), + ) + && push_storage_variant_case( + storage, + self[*i].err.map(|ty| self.flat_interface_type(&ty, limit)), + ) + } + InterfaceType::FixedLengthList(i) => push_storage_n( + storage, + self.flat_interface_type(&self[*i].element, limit), + self[*i].size as usize, + ), + } + .then_some(storage_buf) + } + /// Adds a new `table` to the list of resource tables for this component. pub fn push_resource_table(&mut self, table: TypeResourceTable) -> TypeResourceTableIndex { self.resource_tables.push(table) @@ -1223,6 +1442,15 @@ pub const MAX_FLAT_TYPES: usize = if MAX_FLAT_PARAMS > MAX_FLAT_RESULTS { MAX_FLAT_RESULTS }; +/// Maximum number of parameters that a core wasm function exported/imported through +/// components can contain according to the Canonical ABI. In particular, this +/// can include one potential extra return pointer for canon.lower methods. +pub const MAX_FLAT_PARAMS_ABI: usize = MAX_FLAT_PARAMS + 1; + +/// Maximum number of results that a core wasm function exported/imported through +/// components can contain according to the Canonical ABI. +pub const MAX_FLAT_RESULTS_ABI: usize = MAX_FLAT_RESULTS; + const fn add_flat(a: Option, b: Option) -> Option { const MAX: u8 = MAX_FLAT_TYPES as u8; let sum = match (a, b) { @@ -1248,6 +1476,82 @@ const fn max_flat(a: Option, b: Option) -> Option { } } +/// Representation of flat types in 32-bit and 64-bit memory +/// +/// This could be represented as `Vec` but on 64-bit architectures +/// that's 24 bytes. Otherwise `FlatType` is 1 byte large and +/// `MAX_FLAT_TYPES` is 16, so it should ideally be more space-efficient to +/// use a flat array instead of a heap-based vector. 
+pub struct FlatTypesStorage { + /// Representation for 32-bit memory + pub memory32: [FlatType; N], + /// Representation for 64-bit memory + pub memory64: [FlatType; N], + + /// Tracks the number of flat types pushed into this storage. If this is + /// `N + 1` then this storage represents an unrepresentable + /// type in flat types. + /// + /// This value should be the same on both `memory32` and `memory64` + pub len: u8, +} + +impl FlatTypesStorage { + /// Create a new, empty storage for flat types + pub const fn new() -> FlatTypesStorage { + FlatTypesStorage { + memory32: [FlatType::I32; N], + memory64: [FlatType::I32; N], + len: 0, + } + } + + /// Returns a reference to flat type representation + pub fn as_flat_types(&self) -> Option> { + let len = usize::from(self.len); + if len > N { + assert_eq!(len, N + 1); + None + } else { + Some(FlatTypes { + memory32: &self.memory32[..len], + memory64: &self.memory64[..len], + }) + } + } + + /// Pushes a new flat type into this list using `t32` for 32-bit memories + /// and `t64` for 64-bit memories. + /// + /// Returns whether the type was actually pushed or whether this list of + /// flat types just exceeded the maximum meaning that it is now + /// unrepresentable with a flat list of types. + pub fn push(&mut self, t32: FlatType, t64: FlatType) -> bool { + let len = usize::from(self.len); + if len < N { + self.memory32[len] = t32; + self.memory64[len] = t64; + self.len += 1; + true + } else { + // If this was the first one to go over then flag the length as + // being incompatible with a flat representation. + if len == N { + self.len += 1; + } + false + } + } + + /// Generate an iterator over the byte size of each type in the 32-bit flat encoding + pub fn iter32(&self) -> impl Iterator { + self.memory32 + .iter() + .take(self.len as usize) + .map(|f| f.byte_size()) + } +} + /// Flat representation of a type in just core wasm types. pub struct FlatTypes<'a> { /// The flat representation of this type in 32-bit memories. 
@@ -1277,3 +1581,35 @@ pub enum FlatType { F32, F64, } + +impl FlatType { + /// Constructs the "joined" representation for two flat types + pub fn join(&mut self, other: FlatType) { + if *self == other { + return; + } + *self = match (*self, other) { + (FlatType::I32, FlatType::F32) | (FlatType::F32, FlatType::I32) => FlatType::I32, + _ => FlatType::I64, + }; + } + + /// Return the size in bytes for this flat type + pub const fn byte_size(&self) -> u8 { + match self { + FlatType::I32 | FlatType::F32 => 4, + FlatType::I64 | FlatType::F64 => 8, + } + } +} + +/// Context under which the flat ABI is considered for functypes. +/// +/// Note that this is necessary since the same signature can have different +/// ABIs depending on whether it is a lifted function or a lowered function. +pub enum FlatFuncTypeContext { + /// Flattening args for a lifted function + Lift, + /// Flattening args for a lowered function + Lower, +} diff --git a/crates/environ/src/component/types_builder.rs b/crates/environ/src/component/types_builder.rs index 10422028a3dd..c47ee7598bff 100644 --- a/crates/environ/src/component/types_builder.rs +++ b/crates/environ/src/component/types_builder.rs @@ -925,78 +925,6 @@ where return idx; } -struct FlatTypesStorage { - // This could be represented as `Vec` but on 64-bit architectures - // that's 24 bytes. Otherwise `FlatType` is 1 byte large and - // `MAX_FLAT_TYPES` is 16, so it should ideally be more space-efficient to - // use a flat array instead of a heap-based vector. - memory32: [FlatType; MAX_FLAT_TYPES], - memory64: [FlatType; MAX_FLAT_TYPES], - - // Tracks the number of flat types pushed into this storage. If this is - // `MAX_FLAT_TYPES + 1` then this storage represents an un-reprsentable - // type in flat types. 
- len: u8, -} - -impl FlatTypesStorage { - const fn new() -> FlatTypesStorage { - FlatTypesStorage { - memory32: [FlatType::I32; MAX_FLAT_TYPES], - memory64: [FlatType::I32; MAX_FLAT_TYPES], - len: 0, - } - } - - fn as_flat_types(&self) -> Option> { - let len = usize::from(self.len); - if len > MAX_FLAT_TYPES { - assert_eq!(len, MAX_FLAT_TYPES + 1); - None - } else { - Some(FlatTypes { - memory32: &self.memory32[..len], - memory64: &self.memory64[..len], - }) - } - } - - /// Pushes a new flat type into this list using `t32` for 32-bit memories - /// and `t64` for 64-bit memories. - /// - /// Returns whether the type was actually pushed or whether this list of - /// flat types just exceeded the maximum meaning that it is now - /// unrepresentable with a flat list of types. - fn push(&mut self, t32: FlatType, t64: FlatType) -> bool { - let len = usize::from(self.len); - if len < MAX_FLAT_TYPES { - self.memory32[len] = t32; - self.memory64[len] = t64; - self.len += 1; - true - } else { - // If this was the first one to go over then flag the length as - // being incompatible with a flat representation. - if len == MAX_FLAT_TYPES { - self.len += 1; - } - false - } - } -} - -impl FlatType { - fn join(&mut self, other: FlatType) { - if *self == other { - return; - } - *self = match (*self, other) { - (FlatType::I32, FlatType::F32) | (FlatType::F32, FlatType::I32) => FlatType::I32, - _ => FlatType::I64, - }; - } -} - #[derive(Default)] struct TypeInformationCache { records: PrimaryMap, @@ -1013,7 +941,7 @@ struct TypeInformationCache { struct TypeInformation { depth: u32, - flat: FlatTypesStorage, + flat: FlatTypesStorage, has_borrow: bool, }