diff --git a/Cargo.lock b/Cargo.lock index 4a42447a83f..83d422147d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7691,6 +7691,17 @@ dependencies = [ "walkdir", ] +[[package]] +name = "spacetimedb-bindings-cpp-ffi" +version = "2.1.0" +dependencies = [ + "bytemuck", + "spacetimedb-lib 2.1.0", + "spacetimedb-primitives 2.1.0", + "spacetimedb-sats 2.1.0", + "thiserror 1.0.69", +] + [[package]] name = "spacetimedb-bindings-macro" version = "1.9.0" diff --git a/Cargo.toml b/Cargo.toml index e8938a0dcc5..57fd7c82deb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "crates/bindings-sys", "crates/bindings", "crates/bindings-macro", + "crates/bindings-cpp-ffi", "crates/cli", "crates/client-api", "crates/client-api-messages", diff --git a/crates/bindings-cpp-ffi/Cargo.toml b/crates/bindings-cpp-ffi/Cargo.toml new file mode 100644 index 00000000000..ca02d1f5748 --- /dev/null +++ b/crates/bindings-cpp-ffi/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "spacetimedb-bindings-cpp-ffi" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +description = "Rust replacement for SpacetimeDB C++ bindings type registration and FFI layer" +publish = false + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +spacetimedb-sats = { workspace = true } +spacetimedb-lib = { workspace = true } +spacetimedb-primitives = { workspace = true } +bytemuck = { workspace = true, features = ["must_cast"] } +thiserror = { workspace = true } + +[lints] +workspace = true diff --git a/crates/bindings-cpp-ffi/src/ffi.rs b/crates/bindings-cpp-ffi/src/ffi.rs new file mode 100644 index 00000000000..3bc1fa51e97 --- /dev/null +++ b/crates/bindings-cpp-ffi/src/ffi.rs @@ -0,0 +1,1006 @@ +//! FFI exports matching the SpacetimeDB WASM ABI. +//! +//! These functions are the Rust equivalents of the C++ exports in +//! `module_exports.cpp` and `Module.cpp`. + +#![allow(clippy::disallowed_macros)] + +use crate::module_type_registration::{serialize_module_def, ModuleTypeRegistration, RegistrationError}; +use spacetimedb_lib::bsatn; +use spacetimedb_lib::db::raw_def::v10::{RawModuleDefV10, RawModuleDefV10Section, RawScopedTypeNameV10, RawTypeDefV10}; +use spacetimedb_lib::{ConnectionId, Identity}; +use spacetimedb_sats::raw_identifier::RawIdentifier; +use spacetimedb_sats::AlgebraicType; + +// ============================================================ +// Opaque FFI handles (mirroring the C++ opaque types) +// ============================================================ + +/// Opaque handle provided by the host to write bytes into. +#[repr(C)] +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct BytesSink { + pub inner: u32, +} + +/// Opaque handle provided by the host to read bytes from. +#[repr(C)] +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct BytesSource { + pub inner: u32, +} + +impl BytesSource { + /// Sentinel value indicating an invalid / absent source. + pub const INVALID: Self = Self { inner: 0 }; +} + +// ============================================================ +// Host FFI imports +// ============================================================ + +#[cfg(not(test))] +#[link(wasm_import_module = "spacetimedb")] +unsafe extern "C" { + /// Write bytes to a sink. Returns 0 on success, negative on error. + fn bytes_sink_write(sink: BytesSink, data: *const u8, len: *mut usize) -> i16; + /// Read bytes from a source. Returns -1 when exhausted, 0 on success, negative on error. + fn bytes_source_read(source: BytesSource, buf: *mut u8, len: *mut usize) -> i16; + /// Get the remaining length of a source. Returns 0 on success, negative on error. + fn bytes_source_remaining_length(source: BytesSource, len: *mut u32) -> i16; +} + +// Stub implementations for native testing +#[cfg(test)] +mod host_stubs { + #![allow(dead_code)] + #![allow(unsafe_op_in_unsafe_fn)] + use super::*; + + use std::collections::HashMap; + use std::sync::{Mutex, OnceLock}; + + fn test_sources() -> &'static Mutex>> { + static ONCE: OnceLock>>> = OnceLock::new(); + ONCE.get_or_init(|| Mutex::new(HashMap::new())) + } + + fn test_sinks() -> &'static Mutex>> { + static ONCE: OnceLock>>> = OnceLock::new(); + ONCE.get_or_init(|| Mutex::new(HashMap::new())) + } + + static NEXT_HANDLE: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(1); + + pub fn register_test_source(data: Vec) -> BytesSource { + let handle = NEXT_HANDLE.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + test_sources().lock().unwrap().insert(handle, data); + BytesSource { inner: handle } + } + + pub fn register_test_sink() -> BytesSink { + let handle = NEXT_HANDLE.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + test_sinks().lock().unwrap().insert(handle, Vec::new()); + BytesSink { inner: handle } + } + + pub fn get_sink_data(sink: &BytesSink) -> Vec { + test_sinks() + .lock() + .unwrap() + .get(&sink.inner) + .cloned() + .unwrap_or_default() + } + + pub unsafe fn bytes_sink_write(sink: BytesSink, data: *const u8, len: *mut usize) -> i16 { + if let Some(buf) = test_sinks().lock().unwrap().get_mut(&sink.inner) { + // SAFETY: caller guarantees `len` and `data` are valid + let len_ref = unsafe { &mut *len }; + let slice = unsafe { std::slice::from_raw_parts(data, *len_ref) }; + buf.extend_from_slice(slice); + 0 + } else { + errno::NO_SUCH_BYTES + } + } + + pub unsafe fn bytes_source_read(source: BytesSource, buf: *mut u8, len: *mut usize) -> i16 { + if let Some(data) = test_sources().lock().unwrap().get(&source.inner) { + // SAFETY: caller guarantees `len` and `buf` are valid + let len_ref = unsafe { &mut *len }; + let max = *len_ref; + let available = data.len(); + let to_copy = max.min(available); + if to_copy > 0 { + unsafe { std::ptr::copy_nonoverlapping(data.as_ptr(), buf, to_copy) }; + *len_ref = to_copy; + } + -1 + } else { + errno::NO_SUCH_BYTES + } + } + + pub unsafe fn bytes_source_remaining_length(source: BytesSource, len: *mut u32) -> i16 { + if let Some(data) = test_sources().lock().unwrap().get(&source.inner) { + // SAFETY: caller guarantees `len` is valid + unsafe { *len = data.len() as u32 }; + 0 + } else { + errno::NO_SUCH_BYTES as i32 as i16 + } + } +} + +#[cfg(test)] +use host_stubs::*; + +// ============================================================ +// Status codes (matching C++ StatusCode enum) +// ============================================================ + +mod errno { + pub const OK: i16 = 0; + pub const HOST_CALL_FAILURE: i16 = 1; + pub const NO_SUCH_BYTES: i16 = 2; + pub const NO_SUCH_REDUCER: i16 = 3; + pub const NO_SUCH_VIEW: i16 = 4; + pub const NO_SUCH_PROCEDURE: i16 = 5; +} + +// ============================================================ +// View result header +// ============================================================ + +/// Prepended to view results to indicate the result type. +#[repr(u8)] +enum ViewResultHeader { + RowData = 0, + #[allow(unused)] + RawSql = 1, +} + +// ============================================================ +// Reducer / View / Procedure handler registration +// ============================================================ + +/// A reducer function: takes `(ReducerContext, &[u8])` → `Result<(), Box>`. +/// `ReducerContext` is defined by the consuming SDK. +pub type ReducerFn = for<'a> fn(&[u8]) -> Result<(), Box>; + +/// A view function: takes `&[u8]` → `Vec` (BSATN-encoded rows). +pub type ViewFn = for<'a> fn(&[u8]) -> Vec; + +/// A procedure function: takes `&[u8]` → `Vec` (BSATN-encoded result). +pub type ProcedureFn = for<'a> fn(&[u8]) -> Vec; + +// ============================================================ +// Global module state +// ============================================================ + +struct GlobalModuleState { + type_reg: ModuleTypeRegistration, + reducers: Vec, + views: Vec, + views_anon: Vec, + procedures: Vec, + /// Error flags from constraint / primary-key / circular-ref detection + error_state: Option, +} + +#[derive(Clone, Debug)] +struct ErrorState { + variant: ErrorVariant, +} + +#[allow(dead_code)] +#[derive(Clone, Debug)] +enum ErrorVariant { + CircularReference { type_name: String }, + MultiplePrimaryKeys { table_name: String }, + ConstraintRegistration { code: String, details: String }, + TypeRegistration { message: String, type_description: String }, +} + +impl GlobalModuleState { + fn new() -> Self { + Self { + type_reg: ModuleTypeRegistration::new(), + reducers: Vec::new(), + views: Vec::new(), + views_anon: Vec::new(), + procedures: Vec::new(), + error_state: None, + } + } + + fn clear(&mut self) { + self.type_reg.clear(); + self.reducers.clear(); + self.views.clear(); + self.views_anon.clear(); + self.procedures.clear(); + self.error_state = None; + } +} + +struct SingletonCell(std::cell::UnsafeCell); +// SAFETY: WASM modules are single-threaded. This pattern is used in the Rust +// standard library for WASI singletons. +#[allow(clippy::mut_from_ref)] +unsafe impl Sync for SingletonCell {} +impl SingletonCell { + const fn new(val: T) -> Self { + Self(std::cell::UnsafeCell::new(val)) + } + #[allow(clippy::mut_from_ref)] + fn get(&self) -> &mut T { + unsafe { &mut *self.0.get() } + } +} + +static MODULE: SingletonCell> = SingletonCell::new(None); + +fn get_module() -> &'static mut GlobalModuleState { + let cell = &MODULE; + if cell.get().is_none() { + *cell.get() = Some(GlobalModuleState::new()); + } + cell.get().as_mut().unwrap() +} + +// ============================================================ +// Public registration APIs (called by SDK macros) +// ============================================================ + +/// Register a reducer function. Returns the reducer's index. +pub fn register_reducer(f: ReducerFn) -> u32 { + let module = get_module(); + let id = module.reducers.len() as u32; + module.reducers.push(f); + id +} + +/// Register a view function (with sender identity). Returns the view's index. +pub fn register_view(f: ViewFn) -> u32 { + let module = get_module(); + let id = module.views.len() as u32; + module.views.push(f); + id +} + +/// Register an anonymous view function (no sender). Returns the view's index. +pub fn register_view_anon(f: ViewFn) -> u32 { + let module = get_module(); + let id = module.views_anon.len() as u32; + module.views_anon.push(f); + id +} + +/// Register a procedure function. Returns the procedure's index. +pub fn register_procedure(f: ProcedureFn) -> u32 { + let module = get_module(); + let id = module.procedures.len() as u32; + module.procedures.push(f); + id +} + +/// Register a type with the module's typespace. +pub fn register_type(ty: AlgebraicType, explicit_name: &str) -> AlgebraicType { + let module = get_module(); + module.type_reg.register_type(ty, explicit_name) +} + +/// Check if the module has a registration error. +pub fn has_registration_error() -> bool { + get_module().type_reg.has_error() +} + +/// Get the registration error details. +pub fn registration_error() -> Option<&'static RegistrationError> { + get_module().type_reg.error() +} + +// ============================================================ +// Error module generation (matching C++ __preinit__99_validate_types) +// ============================================================ + +fn make_error_type(name: &str) -> RawTypeDefV10 { + RawTypeDefV10 { + source_name: RawScopedTypeNameV10 { + scope: Box::new([]), + source_name: RawIdentifier::new(name), + }, + ty: spacetimedb_sats::AlgebraicTypeRef(999_999), + custom_ordering: false, + } +} + +fn make_error_module(error_type: RawTypeDefV10) -> RawModuleDefV10 { + let mut module = RawModuleDefV10::default(); + module.sections.push(RawModuleDefV10Section::Types(vec![error_type])); + module +} + +fn sanitize_for_error_name(s: &str) -> String { + s.chars() + .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' }) + .take(100) + .collect() +} + +fn extract_type_name_from_error(message: &str) -> &str { + if let Some(start) = message.find('\'') + && let Some(end) = message[start + 1..].find('\'') + { + return &message[start + 1..start + 1 + end]; + } + "unknown" +} + +/// Build error module definition if there's a registration error. +/// Returns `Some(RawModuleDefV10)` if an error was detected, `None` otherwise. +fn build_error_module_def(state: &mut GlobalModuleState) -> Option { + // 1. Circular reference error + if let Some(ErrorState { + variant: ErrorVariant::CircularReference { type_name }, + }) = &state.error_state + { + let name = format!("ERROR_CIRCULAR_REFERENCE_{type_name}"); + return Some(make_error_module(make_error_type(&name))); + } + + // 2. Multiple primary key error + if let Some(ErrorState { + variant: ErrorVariant::MultiplePrimaryKeys { table_name }, + }) = &state.error_state + { + let name = format!("ERROR_MULTIPLE_PRIMARY_KEYS_{table_name}"); + return Some(make_error_module(make_error_type(&name))); + } + + // 3. Constraint registration error + if let Some(ErrorState { + variant: ErrorVariant::ConstraintRegistration { code, details }, + }) = &state.error_state + { + let sanitized = sanitize_for_error_name(&format!("ERROR_CONSTRAINT_REGISTRATION_{code}")); + eprintln!("\n[CONSTRAINT REGISTRATION ERROR] Module cleared and replaced with error type: {sanitized}"); + eprintln!("Original error: {details}\n"); + return Some(make_error_module(make_error_type(&sanitized))); + } + + // 4. Type registration error + if let Some(err) = state.type_reg.error() { + let message = &err.message; + let type_description = &err.type_description; + + let error_name = if message.contains("Recursive type reference") { + let problematic = extract_type_name_from_error(message); + format!("ERROR_RECURSIVE_TYPE_{problematic}") + } else if message.contains("Missing type name") { + let sanitized = sanitize_for_error_name(type_description); + format!("ERROR_MISSING_TYPE_NAME_{sanitized}") + } else { + "ERROR_TYPE_REGISTRATION_FAILED".to_owned() + }; + + eprintln!("\n[TYPE ERROR] Module cleared and replaced with error type: {error_name}"); + eprintln!("Original error: {message}\n"); + return Some(make_error_module(make_error_type(&error_name))); + } + + None +} + +// ============================================================ +// Helper: read all bytes from a BytesSource +// ============================================================ + +fn read_bytes_source(source: BytesSource) -> Vec { + if source == BytesSource::INVALID { + return Vec::new(); + } + + let mut buf = Vec::new(); + + // Try to get the remaining length for efficient reservation + if let Some(len) = { + let mut len: u32 = 0; + let ret = unsafe { bytes_source_remaining_length(source, &mut len) }; + if ret == 0 { + Some(len as usize) + } else { + None + } + } { + buf.reserve(len); + } else { + buf.reserve(1024); + } + + // Read in a loop to handle partial reads + loop { + let spare = buf.spare_capacity_mut(); + let spare_len = spare.len(); + let mut buf_len = spare.len(); + let ptr = spare.as_mut_ptr().cast::(); + + let ret = unsafe { bytes_source_read(source, ptr, &mut buf_len) }; + + match ret { + -1 => { + // Exhausted — `buf_len` was written, advance now + if buf_len > 0 { + unsafe { buf.set_len(buf.len() + buf_len) }; + } + break; + } + 0 => { + // Partial read — `buf_len` was written, advance + unsafe { buf.set_len(buf.len() + buf_len) }; + if buf_len == spare_len { + buf.reserve(1024); + } + // else: partial read but not exhausted, loop again + } + _ => { + eprintln!("ERROR: Failed to read from BytesSource: {ret}"); + break; + } + } + } + + buf +} + +// ============================================================ +// Helper: write bytes to a BytesSink +// ============================================================ + +fn write_to_sink(sink: BytesSink, mut buf: &[u8]) { + if sink.inner == 0 || buf.is_empty() { + return; + } + + loop { + let mut len = buf.len(); + let ret = unsafe { bytes_sink_write(sink, buf.as_ptr(), &mut len) }; + + match ret { + 0 => { + // Advance past the written bytes + buf = &buf[len..]; + if buf.is_empty() { + break; + } + } + errno::NO_SUCH_BYTES => panic!("invalid BytesSink passed"), + errno::HOST_CALL_FAILURE => panic!("no space left at sink"), + _ => { + eprintln!("ERROR: Failed to write to BytesSink: {ret}"); + break; + } + } + } +} + +// ============================================================ +// Helper: reconstruct Identity from 4x u64 (little-endian) +// ============================================================ + +fn reconstruct_identity(s0: u64, s1: u64, s2: u64, s3: u64) -> Identity { + // Identity is 32 bytes, stored little-endian + let mut bytes = [0u8; 32]; + bytes[0..8].copy_from_slice(&s0.to_le_bytes()); + bytes[8..16].copy_from_slice(&s1.to_le_bytes()); + bytes[16..24].copy_from_slice(&s2.to_le_bytes()); + bytes[24..32].copy_from_slice(&s3.to_le_bytes()); + Identity::from_byte_array(bytes) +} + +// ============================================================ +// Helper: reconstruct ConnectionId from 2x u64 (little-endian) +// ============================================================ + +fn reconstruct_connection_id(c0: u64, c1: u64) -> Option { + let mut bytes = [0u8; 16]; + bytes[0..8].copy_from_slice(&c0.to_le_bytes()); + bytes[8..16].copy_from_slice(&c1.to_le_bytes()); + let conn_id = ConnectionId::from_le_byte_array(bytes); + if conn_id == ConnectionId::ZERO { + None + } else { + Some(conn_id) + } +} + +// ============================================================ +// __preinit__01_clear_global_state +// ============================================================ + +#[unsafe(export_name = "__preinit__01_clear_global_state")] +extern "C" fn preinit_clear_global_state() { + get_module().clear(); +} + +// ============================================================ +// __preinit__99_validate_types +// ============================================================ + +#[unsafe(export_name = "__preinit__99_validate_types")] +extern "C" fn preinit_validate_types() { + let state = get_module(); + + // If there's an error state, the module definition will be replaced + // with an error type that SpacetimeDB will reject with a clear message + if state.error_state.is_some() || state.type_reg.has_error() { + // Build the error module — this replaces the normal module + if let Some(_error_module) = build_error_module_def(state) { + // Error module is built; when __describe_module__ runs, + // it will serialize this error module instead + } + } +} + +// ============================================================ +// __describe_module__ +// ============================================================ + +#[unsafe(no_mangle)] +pub extern "C" fn __describe_module__(description: BytesSink) { + let state = get_module(); + + // Check for errors — if present, build error module + let buffer = if state.error_state.is_some() || state.type_reg.has_error() { + if let Some(error_module) = build_error_module_def(state) { + let versioned = spacetimedb_lib::RawModuleDef::V10(error_module); + bsatn::to_vec(&versioned).expect("failed to serialize error module") + } else { + // No error after all — serialize normal module + serialize_module_def(&state.type_reg) + } + } else { + serialize_module_def(&state.type_reg) + }; + + if !buffer.is_empty() { + write_to_sink(description, &buffer); + } +} + +// ============================================================ +// __call_reducer__ +// ============================================================ + +#[unsafe(no_mangle)] +pub extern "C" fn __call_reducer__( + id: u32, + _sender_0: u64, + _sender_1: u64, + _sender_2: u64, + _sender_3: u64, + _conn_id_0: u64, + _conn_id_1: u64, + _timestamp_us: u64, + args: BytesSource, + error: BytesSink, +) -> i16 { + let state = get_module(); + + // Validate reducer ID + if id as usize >= state.reducers.len() { + let msg = format!("Invalid reducer ID: {id}"); + write_to_sink(error, msg.as_bytes()); + return errno::NO_SUCH_REDUCER; + } + + // Read args + let args_bytes = read_bytes_source(args); + + // Dispatch + let reducer_fn = state.reducers[id as usize]; + let result = reducer_fn(&args_bytes); + + // Handle errors + match result { + Ok(()) => errno::OK, + Err(msg) => { + write_to_sink(error, msg.as_bytes()); + errno::HOST_CALL_FAILURE + } + } +} + +// ============================================================ +// __call_view__ (with sender identity) +// ============================================================ + +#[unsafe(no_mangle)] +pub extern "C" fn __call_view__( + id: u32, + sender_0: u64, + sender_1: u64, + sender_2: u64, + sender_3: u64, + args: BytesSource, + result: BytesSink, +) -> i16 { + let state = get_module(); + + // Validate view ID + if id as usize >= state.views.len() { + eprintln!("ERROR: Invalid view ID {id} (have {} views)", state.views.len()); + return errno::NO_SUCH_VIEW; + } + + // Reconstruct sender identity (C++ builds Identity but doesn't use it for views currently) + let _sender = reconstruct_identity(sender_0, sender_1, sender_2, sender_3); + + // Read args + let args_bytes = read_bytes_source(args); + + // Dispatch + let view_fn = state.views[id as usize]; + let result_data = view_fn(&args_bytes); + + // Serialize ViewResultHeader::RowData + result + let mut full_result = Vec::with_capacity(1 + result_data.len()); + full_result.push(ViewResultHeader::RowData as u8); + full_result.extend_from_slice(&result_data); + + write_to_sink(result, &full_result); + 2 // Success with data (new ABI) +} + +// ============================================================ +// __call_view_anon__ (no sender identity) +// ============================================================ + +#[unsafe(no_mangle)] +pub extern "C" fn __call_view_anon__(id: u32, args: BytesSource, result: BytesSink) -> i16 { + let state = get_module(); + + // Validate view ID + if id as usize >= state.views_anon.len() { + eprintln!( + "ERROR: Invalid anonymous view ID {id} (have {} anonymous views)", + state.views_anon.len() + ); + return errno::NO_SUCH_VIEW; + } + + // Read args + let args_bytes = read_bytes_source(args); + + // Dispatch + let view_fn = state.views_anon[id as usize]; + let result_data = view_fn(&args_bytes); + + // Serialize ViewResultHeader::RowData + result + let mut full_result = Vec::with_capacity(1 + result_data.len()); + full_result.push(ViewResultHeader::RowData as u8); + full_result.extend_from_slice(&result_data); + + write_to_sink(result, &full_result); + 2 // Success with data (new ABI) +} + +// ============================================================ +// __call_procedure__ +// ============================================================ + +#[unsafe(no_mangle)] +pub extern "C" fn __call_procedure__( + id: u32, + sender_0: u64, + sender_1: u64, + sender_2: u64, + sender_3: u64, + conn_id_0: u64, + conn_id_1: u64, + _timestamp_microseconds: u64, + args_source: BytesSource, + result_sink: BytesSink, +) -> i16 { + let state = get_module(); + + // Validate procedure ID + if id as usize >= state.procedures.len() { + eprintln!( + "ERROR: Invalid procedure ID {id} (have {} procedures)", + state.procedures.len() + ); + return errno::NO_SUCH_PROCEDURE; + } + + // Reconstruct sender identity (for context, though procedure fn doesn't receive it here) + let _sender = reconstruct_identity(sender_0, sender_1, sender_2, sender_3); + let _conn_id = reconstruct_connection_id(conn_id_0, conn_id_1); + + // Read args + let args_bytes = read_bytes_source(args_source); + + // Dispatch + let procedure_fn = state.procedures[id as usize]; + let result_data = procedure_fn(&args_bytes); + + // Write result + write_to_sink(result_sink, &result_data); + 0 // Success +} + +// ============================================================ +// Tests +// ============================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn register_and_dispatch_reducer() { + // Clear global state + get_module().clear(); + + // Register a simple reducer + let id = register_reducer(|_args| Ok(())); + assert_eq!(id, 0); + + // Register a second reducer + let id2 = register_reducer(|_args| Err("test error".into())); + assert_eq!(id2, 1); + + // Verify reducers are stored + let state = get_module(); + assert_eq!(state.reducers.len(), 2); + } + + #[test] + fn register_and_dispatch_view() { + get_module().clear(); + + let id = register_view(|_args| vec![1, 2, 3]); + assert_eq!(id, 0); + + let id2 = register_view(|_args| vec![4, 5]); + assert_eq!(id2, 1); + + let state = get_module(); + assert_eq!(state.views.len(), 2); + } + + #[test] + fn register_and_dispatch_view_anon() { + get_module().clear(); + + let id = register_view_anon(|_args| vec![10, 20]); + assert_eq!(id, 0); + + let state = get_module(); + assert_eq!(state.views_anon.len(), 1); + } + + #[test] + fn register_and_dispatch_procedure() { + get_module().clear(); + + let id = register_procedure(|_args| vec![100, 200]); + assert_eq!(id, 0); + + let state = get_module(); + assert_eq!(state.procedures.len(), 1); + } + + #[test] + fn read_bytes_source_invalid_returns_empty() { + let result = read_bytes_source(BytesSource::INVALID); + assert!(result.is_empty()); + } + + #[test] + fn write_to_sink_empty_buffer_noop() { + // Writing empty buffer should not panic + write_to_sink(BytesSink { inner: 0 }, &[]); + } + + #[test] + fn write_to_sink_zero_inner_noop() { + // Sink with inner=0 should not panic + write_to_sink(BytesSink { inner: 0 }, &[1, 2, 3]); + } + + #[test] + fn reconstruct_identity_little_endian() { + let identity = reconstruct_identity( + 0x0102030405060708, + 0x090a0b0c0d0e0f10, + 0x1112131415161718, + 0x191a1b1c1d1e1f20, + ); + let bytes = identity.to_byte_array(); + assert_eq!(&bytes[0..8], &0x0102030405060708u64.to_le_bytes()); + assert_eq!(&bytes[8..16], &0x090a0b0c0d0e0f10u64.to_le_bytes()); + assert_eq!(&bytes[16..24], &0x1112131415161718u64.to_le_bytes()); + assert_eq!(&bytes[24..32], &0x191a1b1c1d1e1f20u64.to_le_bytes()); + } + + #[test] + fn reconstruct_connection_id_little_endian() { + let conn_id = reconstruct_connection_id(0x0102030405060708, 0x090a0b0c0d0e0f10).unwrap(); + let bytes = conn_id.as_le_byte_array(); + assert_eq!(&bytes[0..8], &0x0102030405060708u64.to_le_bytes()); + assert_eq!(&bytes[8..16], &0x090a0b0c0d0e0f10u64.to_le_bytes()); + } + + #[test] + fn reconstruct_connection_id_zero_returns_none() { + let result = reconstruct_connection_id(0, 0); + assert!(result.is_none()); + } + + #[test] + fn sanitize_for_error_name_replaces_special_chars() { + assert_eq!(sanitize_for_error_name("foo bar"), "foo_bar"); + assert_eq!(sanitize_for_error_name("foo/bar"), "foo_bar"); + assert_eq!(sanitize_for_error_name("foo@bar"), "foo_bar"); + } + + #[test] + fn sanitize_for_error_name_truncates_long() { + let long = "a".repeat(200); + let result = sanitize_for_error_name(&long); + assert_eq!(result.len(), 100); + } + + #[test] + fn sanitize_for_error_name_keeps_alphanumeric() { + assert_eq!(sanitize_for_error_name("ABC_123"), "ABC_123"); + } + + #[test] + fn extract_type_name_from_error_finds_quoted_name() { + assert_eq!( + extract_type_name_from_error("Recursive type reference: 'MyType' is referencing itself"), + "MyType" + ); + } + + #[test] + fn extract_type_name_from_error_returns_unknown() { + assert_eq!(extract_type_name_from_error("no quotes here"), "unknown"); + } + + #[test] + fn make_error_type_correct_structure() { + let error_type = make_error_type("ERROR_TEST"); + assert_eq!(&*error_type.source_name.source_name, "ERROR_TEST"); + assert!(error_type.source_name.scope.is_empty()); + assert_eq!(error_type.ty.0, 999_999); + assert!(!error_type.custom_ordering); + } + + #[test] + fn make_error_module_has_types_section() { + let module = make_error_module(make_error_type("ERROR_TEST")); + assert_eq!(module.sections.len(), 1); + match &module.sections[0] { + RawModuleDefV10Section::Types(types) => { + assert_eq!(types.len(), 1); + assert_eq!(&*types[0].source_name.source_name, "ERROR_TEST"); + } + _ => panic!("expected Types section"), + } + } + + #[test] + fn view_result_header_values() { + assert_eq!(ViewResultHeader::RowData as u8, 0); + assert_eq!(ViewResultHeader::RawSql as u8, 1); + } + + #[test] + fn bytes_source_invalid_constant() { + assert_eq!(BytesSource::INVALID.inner, 0); + } + + #[test] + fn error_state_circular_reference() { + get_module().clear(); + get_module().error_state = Some(ErrorState { + variant: ErrorVariant::CircularReference { + type_name: "RecursiveType".to_owned(), + }, + }); + + let module = build_error_module_def(get_module()).unwrap(); + match &module.sections[0] { + RawModuleDefV10Section::Types(types) => { + assert_eq!( + &*types[0].source_name.source_name, + "ERROR_CIRCULAR_REFERENCE_RecursiveType" + ); + } + _ => panic!("expected Types section"), + } + } + + #[test] + fn error_state_multiple_primary_keys() { + get_module().clear(); + get_module().error_state = Some(ErrorState { + variant: ErrorVariant::MultiplePrimaryKeys { + table_name: "MyTable".to_owned(), + }, + }); + + let module = build_error_module_def(get_module()).unwrap(); + match &module.sections[0] { + RawModuleDefV10Section::Types(types) => { + assert_eq!( + &*types[0].source_name.source_name, + "ERROR_MULTIPLE_PRIMARY_KEYS_MyTable" + ); + } + _ => panic!("expected Types section"), + } + } + + #[test] + fn error_state_constraint_registration() { + get_module().clear(); + get_module().error_state = Some(ErrorState { + variant: ErrorVariant::ConstraintRegistration { + code: "PK_CONFLICT".to_owned(), + details: "Duplicate primary key".to_owned(), + }, + }); + + let module = build_error_module_def(get_module()).unwrap(); + match &module.sections[0] { + RawModuleDefV10Section::Types(types) => { + let name = &*types[0].source_name.source_name; + assert!(name.starts_with("ERROR_CONSTRAINT_REGISTRATION_PK_CONFLICT")); + } + _ => panic!("expected Types section"), + } + } + + #[test] + fn error_state_type_registration_recursive() { + get_module().clear(); + get_module() + .type_reg + .register_type(spacetimedb_sats::AlgebraicType::U8, ""); + // Manually set an error + get_module().type_reg.clear(); + // We can't easily set the internal error, so test via build_error_module_def with None + assert!(build_error_module_def(get_module()).is_none()); + } + + #[test] + fn no_error_returns_none() { + get_module().clear(); + assert!(build_error_module_def(get_module()).is_none()); + } + + #[test] + fn clear_resets_all_handlers() { + get_module().clear(); + register_reducer(|_| Ok(())); + register_view(|_| vec![]); + register_view_anon(|_| vec![]); + register_procedure(|_| vec![]); + + get_module().clear(); + + let state = get_module(); + assert!(state.reducers.is_empty()); + assert!(state.views.is_empty()); + assert!(state.views_anon.is_empty()); + assert!(state.procedures.is_empty()); + assert!(state.error_state.is_none()); + assert!(!state.type_reg.has_error()); + } +} diff --git a/crates/bindings-cpp-ffi/src/lib.rs b/crates/bindings-cpp-ffi/src/lib.rs new file mode 100644 index 00000000000..aa8f61b709c --- /dev/null +++ b/crates/bindings-cpp-ffi/src/lib.rs @@ -0,0 +1,14 @@ +//! Rust replacement for SpacetimeDB C++ bindings type registration system. +//! +//! This crate rewrites the core type registration logic from +//! `module_type_registration.cpp` in idiomatic Rust, using the existing +//! `spacetimedb-sats` and `spacetimedb-lib` crates for type definitions. + +mod ffi; +mod module_type_registration; + +pub use ffi::{ + has_registration_error, register_procedure, register_reducer, register_type, register_view, register_view_anon, + registration_error, +}; +pub use module_type_registration::*; diff --git a/crates/bindings-cpp-ffi/src/module_type_registration.rs b/crates/bindings-cpp-ffi/src/module_type_registration.rs new file mode 100644 index 00000000000..1ab367e3bbb --- /dev/null +++ b/crates/bindings-cpp-ffi/src/module_type_registration.rs @@ -0,0 +1,1349 @@ +//! Module type registration system. +//! +//! This module rewrites the core type registration logic from the C++ +//! `module_type_registration.cpp` in idiomatic Rust. It handles: +//! +//! - Type registration and caching +//! - Circular reference detection +//! - Primitive, special, option, result, and ScheduleAt type classification +//! - Complex type (struct/enum) processing for the module typespace + +use spacetimedb_lib::db::raw_def::v10::{RawModuleDefV10, RawModuleDefV10Section, RawScopedTypeNameV10, RawTypeDefV10}; +use spacetimedb_lib::RawModuleDef; +use spacetimedb_sats::raw_identifier::RawIdentifier; +use spacetimedb_sats::typespace::Typespace; +use spacetimedb_sats::{ + AlgebraicType, AlgebraicTypeRef, ArrayType, ProductType, ProductTypeElement, SumType, SumTypeVariant, +}; +use std::collections::{HashMap, HashSet}; + +/// Error state from type registration +#[derive(Debug, Clone)] +pub struct RegistrationError { + pub message: String, + pub type_description: String, +} + +/// Module type registration system. +/// +/// This is the Rust equivalent of C++ `ModuleTypeRegistration`. +/// It owns a typespace and type defs, and handles type registration. +pub struct ModuleTypeRegistration { + /// Cache mapping type names to typespace indices + type_name_cache: HashMap, + /// Set of types currently being registered (for cycle detection) + types_being_registered: HashSet, + /// Error if registration failed + error: Option, + /// The typespace being built + typespace: Typespace, + /// Type definition exports + type_defs: Vec, +} + +impl Default for ModuleTypeRegistration { + fn default() -> Self { + Self::new() + } +} + +impl ModuleTypeRegistration { + pub fn new() -> Self { + Self { + type_name_cache: HashMap::new(), + types_being_registered: HashSet::new(), + error: None, + typespace: Typespace { types: Vec::new() }, + type_defs: Vec::new(), + } + } + + pub fn clear(&mut self) { + self.type_name_cache.clear(); + self.types_being_registered.clear(); + self.error = None; + self.typespace.types.clear(); + self.type_defs.clear(); + } + + pub fn has_error(&self) -> bool { + self.error.is_some() + } + + pub fn error(&self) -> Option<&RegistrationError> { + self.error.as_ref() + } + + /// Access the typespace + pub fn typespace(&self) -> &Typespace { + &self.typespace + } + + /// Access the type definitions + pub fn type_defs(&self) -> &[RawTypeDefV10] { + &self.type_defs + } + + // ============================================================ + // Type classification (all `&self` free, pure functions) + // ============================================================ + + pub fn is_primitive(ty: &AlgebraicType) -> bool { + ty.is_bool() + || ty.is_u8() + || ty.is_u16() + || ty.is_u32() + || ty.is_u64() + || ty.is_u128() + || ty.is_u256() + || ty.is_i8() + || ty.is_i16() + || ty.is_i32() + || ty.is_i64() + || ty.is_i128() + || ty.is_i256() + || ty.is_f32() + || ty.is_f64() + || ty.is_string() + } + + pub fn is_special_type(ty: &AlgebraicType) -> bool { + let Some(product) = ty.as_product() else { return false }; + if product.elements.len() != 1 { + return false; + } + product.elements[0].has_name("__identity__") + || product.elements[0].has_name("__connection_id__") + || product.elements[0].has_name("__timestamp_micros_since_unix_epoch__") + || product.elements[0].has_name("__time_duration_micros__") + || product.elements[0].has_name("__uuid__") + } + + pub fn is_option_type(ty: &AlgebraicType) -> bool { + let Some(sum) = ty.as_sum() else { return false }; + sum.variants.len() == 2 && sum.variants[0].has_name("some") && sum.variants[1].has_name("none") + } + + pub fn is_result_type(ty: &AlgebraicType) -> bool { + let Some(sum) = ty.as_sum() else { return false }; + sum.variants.len() == 2 && sum.variants[0].has_name("ok") && sum.variants[1].has_name("err") + } + + pub fn is_schedule_at_type(ty: &AlgebraicType) -> bool { + let Some(sum) = ty.as_sum() else { return false }; + sum.variants.len() == 2 && sum.variants[0].has_name("Interval") && sum.variants[1].has_name("Time") + } + + pub fn is_unit_type(ty: &AlgebraicType) -> bool { + ty.as_product().is_some_and(|p| p.elements.is_empty()) + } + + // ============================================================ + // Type conversion helpers + // ============================================================ + + fn convert_unit_type() -> AlgebraicType { + AlgebraicType::Product(ProductType { + elements: Vec::new().into(), + }) + } + + fn convert_array(&mut self, elem: AlgebraicType) -> AlgebraicType { + let elem = self.register_type(elem, ""); + AlgebraicType::Array(ArrayType { + elem_ty: Box::new(elem), + }) + } + + fn convert_special_type(&mut self, ty: &AlgebraicType) -> AlgebraicType { + let Some(product) = ty.as_product() else { + return AlgebraicType::U8; + }; + let elements: Box<[_]> = product + .elements + .iter() + .map(|f| ProductTypeElement { + name: f.name.clone(), + algebraic_type: self.register_type(f.algebraic_type.clone(), ""), + }) + .collect(); + AlgebraicType::Product(ProductType { elements }) + } + + fn convert_inline_sum(&mut self, ty: &AlgebraicType) -> AlgebraicType { + let Some(sum) = ty.as_sum() else { + return AlgebraicType::U8; + }; + let variants: Vec<_> = sum + .variants + .iter() + .map(|v| SumTypeVariant { + name: v.name.clone(), + algebraic_type: self.register_type(v.algebraic_type.clone(), ""), + }) + .collect(); + AlgebraicType::Sum(SumType { + variants: variants.into(), + }) + } + + // ============================================================ + // Name handling + // ============================================================ + + pub fn extract_type_name(cpp_type: &str) -> String { + let mut name = cpp_type; + if let Some(pos) = name.rfind("::") { + name = &name[pos + 2..]; + } + if let Some(pos) = name.find('<') { + name = &name[..pos]; + } + name.to_owned() + } + + pub fn parse_namespace_and_name(qualified_name: &str) -> (Vec, RawIdentifier) { + if let Some(last_dot) = qualified_name.rfind('.') { + let ns = &qualified_name[..last_dot]; + let name = RawIdentifier::new(&qualified_name[last_dot + 1..]); + let scope: Vec<_> = ns + .split('.') + .filter(|s| !s.is_empty()) + .map(RawIdentifier::new) + .collect(); + (scope, name) + } else { + (Vec::new(), RawIdentifier::new(qualified_name)) + } + } + + pub fn describe_type(ty: &AlgebraicType) -> String { + match ty { + AlgebraicType::Ref(r) => format!("Ref({})", r.0), + AlgebraicType::Bool => "Bool".into(), + AlgebraicType::I8 => "I8".into(), + AlgebraicType::U8 => "U8".into(), + AlgebraicType::I16 => "I16".into(), + AlgebraicType::U16 => "U16".into(), + AlgebraicType::I32 => "I32".into(), + AlgebraicType::U32 => "U32".into(), + AlgebraicType::I64 => "I64".into(), + AlgebraicType::U64 => "U64".into(), + AlgebraicType::I128 => "I128".into(), + AlgebraicType::U128 => "U128".into(), + AlgebraicType::I256 => "I256".into(), + AlgebraicType::U256 => "U256".into(), + AlgebraicType::F32 => "F32".into(), + AlgebraicType::F64 => "F64".into(), + AlgebraicType::String => "String".into(), + AlgebraicType::Array(arr) => format!("Array({})", Self::describe_type(&arr.elem_ty)), + AlgebraicType::Product(p) => { + if p.elements.is_empty() { + return "Product{}".into(); + } + let elems: Vec<_> = p + .elements + .iter() + .map(|e| { + let t = Self::describe_type(&e.algebraic_type); + match &e.name { + Some(n) => format!("{n}: {t}"), + None => t, + } + }) + .collect(); + format!("Product{{{}}}", elems.join(", ")) + } + AlgebraicType::Sum(s) => { + if s.variants.is_empty() { + return "Sum{}".into(); + } + if Self::is_option_type(ty) { + return format!("Option<{}>", Self::describe_type(&s.variants[0].algebraic_type)); + } + let vars: Vec<_> = s + .variants + .iter() + .map(|v| { + let n = v.name.as_deref().unwrap_or_default(); + format!("{n}: {}", Self::describe_type(&v.algebraic_type)) + }) + .collect(); + format!("Sum{{{}}}", vars.join(" | ")) + } + } + } + + // ============================================================ + // Core registration + // ============================================================ + + pub fn register_type(&mut self, ty: AlgebraicType, explicit_name: &str) -> AlgebraicType { + // 1. Primitives + if Self::is_primitive(&ty) { + return ty; + } + // 2. Refs + if let AlgebraicType::Ref(r) = ty { + return AlgebraicType::Ref(r); + } + // 3. Arrays + if let AlgebraicType::Array(arr) = &ty { + return self.convert_array((*arr.elem_ty).clone()); + } + // 4. Unit types + if Self::is_unit_type(&ty) && explicit_name.is_empty() { + return Self::convert_unit_type(); + } + // 5. Special types + if Self::is_special_type(&ty) { + return self.convert_special_type(&ty); + } + // 5b. ScheduleAt + if Self::is_schedule_at_type(&ty) { + return self.convert_inline_sum(&ty); + } + // 6. Options + if Self::is_option_type(&ty) { + return self.convert_inline_sum(&ty); + } + // 7. Results + if Self::is_result_type(&ty) { + return self.convert_inline_sum(&ty); + } + + // === Complex types below === + + // 8. Type name + let mut type_name = if !explicit_name.is_empty() { + explicit_name.to_owned() + } else { + String::new() + }; + if let Some(pos) = type_name.rfind("::") { + type_name = type_name[pos + 2..].to_owned(); + } + + if type_name.is_empty() { + self.error = Some(RegistrationError { + type_description: Self::describe_type(&ty), + message: format!("Missing type name for complex type: {}", Self::describe_type(&ty)), + }); + return AlgebraicType::U8; + } + + // 9. Circular ref detection + if self.types_being_registered.contains(&type_name) { + self.error = Some(RegistrationError { + type_description: Self::describe_type(&ty), + message: format!("Recursive type reference detected: '{type_name}' is referencing itself"), + }); + return AlgebraicType::U8; + } + + // 10. Cache check + if let Some(&idx) = self.type_name_cache.get(&type_name) { + return AlgebraicType::Ref(idx); + } + + // 11. Register + self.register_complex_type(ty, &type_name) + } + + fn register_complex_type(&mut self, ty: AlgebraicType, type_name: &str) -> AlgebraicType { + self.types_being_registered.insert(type_name.to_owned()); + let idx = AlgebraicTypeRef(self.typespace.types.len() as u32); + + let processed = match &ty { + AlgebraicType::Product(_) => self.process_product(&ty), + AlgebraicType::Sum(_) => self.process_sum(&ty), + _ => { + self.types_being_registered.remove(type_name); + return Self::convert_unit_type(); + } + }; + + self.typespace.types.push(processed); + + let (scope, name) = Self::parse_namespace_and_name(type_name); + self.type_defs.push(RawTypeDefV10 { + source_name: RawScopedTypeNameV10 { + scope: scope.into(), + source_name: name, + }, + ty: idx, + custom_ordering: true, + }); + + self.type_name_cache.insert(type_name.to_owned(), idx); + self.types_being_registered.remove(type_name); + AlgebraicType::Ref(idx) + } + + fn process_product(&mut self, ty: &AlgebraicType) -> AlgebraicType { + let Some(product) = ty.as_product() else { + return AlgebraicType::U8; + }; + let elements: Box<[_]> = product + .elements + .iter() + .map(|f| ProductTypeElement { + name: f.name.clone(), + algebraic_type: self.register_type(f.algebraic_type.clone(), ""), + }) + .collect(); + AlgebraicType::Product(ProductType { elements }) + } + + fn process_sum(&mut self, ty: &AlgebraicType) -> AlgebraicType { + let Some(sum) = ty.as_sum() else { + return AlgebraicType::U8; + }; + let variants: Vec<_> = sum + .variants + .iter() + .map(|v| SumTypeVariant { + name: v.name.clone(), + algebraic_type: self.register_type(v.algebraic_type.clone(), ""), + }) + .collect(); + AlgebraicType::Sum(SumType { + variants: variants.into(), + }) + } + + /// Build the final `RawModuleDefV10` + pub fn build_module_def(&self) -> RawModuleDefV10 { + let mut module = RawModuleDefV10::default(); + module + .sections + .push(RawModuleDefV10Section::Typespace(self.typespace.clone())); + module + .sections + .push(RawModuleDefV10Section::Types(self.type_defs.clone())); + module + } +} + +/// Serialize the module definition to BSATN bytes +pub fn serialize_module_def(reg: &ModuleTypeRegistration) -> Vec { + let module = reg.build_module_def(); + let versioned = RawModuleDef::V10(module); + spacetimedb_sats::bsatn::to_vec(&versioned).expect("failed to serialize module definition") +} + +#[cfg(test)] +mod tests { + use super::*; + + // ============================================================ + // Type classification + // ============================================================ + + #[test] + fn is_primitive_all() { + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::Bool)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::U8)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::U16)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::U32)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::U64)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::U128)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::U256)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::I8)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::I16)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::I32)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::I64)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::I128)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::I256)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::F32)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::F64)); + assert!(ModuleTypeRegistration::is_primitive(&AlgebraicType::String)); + } + + #[test] + fn is_primitive_rejects_composite() { + assert!(!ModuleTypeRegistration::is_primitive(&AlgebraicType::Array( + ArrayType { + elem_ty: Box::new(AlgebraicType::U8) + } + ))); + assert!(!ModuleTypeRegistration::is_primitive(&AlgebraicType::Product( + ProductType { + elements: vec![].into() + } + ))); + assert!(!ModuleTypeRegistration::is_primitive(&AlgebraicType::Sum(SumType { + variants: vec![].into() + }))); + assert!(!ModuleTypeRegistration::is_primitive(&AlgebraicType::Ref( + AlgebraicTypeRef(0) + ))); + } + + #[test] + fn is_special_type_all_specials() { + for name in [ + "__identity__", + "__connection_id__", + "__timestamp_micros_since_unix_epoch__", + "__time_duration_micros__", + "__uuid__", + ] { + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new(name)), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + assert!(ModuleTypeRegistration::is_special_type(&ty), "{name} should be special"); + } + } + + #[test] + fn is_special_type_non_special() { + let normal = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + assert!(!ModuleTypeRegistration::is_special_type(&normal)); + assert!(!ModuleTypeRegistration::is_special_type(&AlgebraicType::U8)); + assert!(!ModuleTypeRegistration::is_special_type(&AlgebraicType::Sum(SumType { + variants: vec![].into() + }))); + } + + #[test] + fn is_special_type_multiple_fields() { + let multi = AlgebraicType::Product(ProductType { + elements: vec![ + ProductTypeElement { + name: Some(RawIdentifier::new("__identity__")), + algebraic_type: AlgebraicType::U8, + }, + ProductTypeElement { + name: Some(RawIdentifier::new("other")), + algebraic_type: AlgebraicType::U8, + }, + ] + .into(), + }); + assert!(!ModuleTypeRegistration::is_special_type(&multi)); + } + + fn make_option() -> AlgebraicType { + AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("some")), + algebraic_type: AlgebraicType::U8, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("none")), + algebraic_type: AlgebraicType::Product(ProductType { + elements: vec![].into(), + }), + }, + ] + .into(), + }) + } + + fn make_result() -> AlgebraicType { + AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("ok")), + algebraic_type: AlgebraicType::U8, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("err")), + algebraic_type: AlgebraicType::String, + }, + ] + .into(), + }) + } + + fn make_schedule_at() -> AlgebraicType { + AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("Interval")), + algebraic_type: AlgebraicType::I64, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("Time")), + algebraic_type: AlgebraicType::I64, + }, + ] + .into(), + }) + } + + #[test] + fn is_option_type_valid() { + assert!(ModuleTypeRegistration::is_option_type(&make_option())); + } + #[test] + fn is_option_type_wrong_names() { + assert!(!ModuleTypeRegistration::is_option_type(&make_result())); + } + #[test] + fn is_option_type_wrong_count() { + let one = AlgebraicType::Sum(SumType { + variants: vec![SumTypeVariant { + name: Some(RawIdentifier::new("some")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + assert!(!ModuleTypeRegistration::is_option_type(&one)); + let three = AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("some")), + algebraic_type: AlgebraicType::U8, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("none")), + algebraic_type: AlgebraicType::U8, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("extra")), + algebraic_type: AlgebraicType::U8, + }, + ] + .into(), + }); + assert!(!ModuleTypeRegistration::is_option_type(&three)); + } + #[test] + fn is_option_type_non_sum() { + assert!(!ModuleTypeRegistration::is_option_type(&AlgebraicType::U8)); + assert!(!ModuleTypeRegistration::is_option_type(&AlgebraicType::Product( + ProductType { + elements: vec![].into() + } + ))); + } + + #[test] + fn is_result_type_valid() { + assert!(ModuleTypeRegistration::is_result_type(&make_result())); + } + #[test] + fn is_result_type_wrong_names() { + assert!(!ModuleTypeRegistration::is_result_type(&make_option())); + } + #[test] + fn is_result_type_non_sum() { + assert!(!ModuleTypeRegistration::is_result_type(&AlgebraicType::Bool)); + } + + #[test] + fn is_schedule_at_type_valid() { + assert!(ModuleTypeRegistration::is_schedule_at_type(&make_schedule_at())); + } + #[test] + fn is_schedule_at_type_wrong_names() { + let wrong = AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("Interval")), + algebraic_type: AlgebraicType::I64, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("Wrong")), + algebraic_type: AlgebraicType::I64, + }, + ] + .into(), + }); + assert!(!ModuleTypeRegistration::is_schedule_at_type(&wrong)); + } + #[test] + fn is_schedule_at_type_non_sum() { + assert!(!ModuleTypeRegistration::is_schedule_at_type(&AlgebraicType::Product( + ProductType { + elements: vec![].into() + } + ))); + } + + #[test] + fn is_unit_type_empty_product() { + let unit = AlgebraicType::Product(ProductType { + elements: vec![].into(), + }); + assert!(ModuleTypeRegistration::is_unit_type(&unit)); + } + #[test] + fn is_unit_type_non_unit() { + let non = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + assert!(!ModuleTypeRegistration::is_unit_type(&non)); + } + #[test] + fn is_unit_type_non_product() { + assert!(!ModuleTypeRegistration::is_unit_type(&AlgebraicType::Sum(SumType { + variants: vec![].into() + }))); + assert!(!ModuleTypeRegistration::is_unit_type(&AlgebraicType::U8)); + } + + // ============================================================ + // Name handling + // ============================================================ + + #[test] + fn extract_type_name_no_namespace() { + assert_eq!(ModuleTypeRegistration::extract_type_name("MyType"), "MyType"); + } + #[test] + fn extract_type_name_with_namespace() { + assert_eq!( + ModuleTypeRegistration::extract_type_name("SpacetimeDB::Internal::MyType"), + "MyType" + ); + } + #[test] + fn extract_type_name_with_template() { + assert_eq!(ModuleTypeRegistration::extract_type_name("std::vector"), "vector"); + assert_eq!(ModuleTypeRegistration::extract_type_name("MyType"), "MyType"); + } + #[test] + fn extract_type_name_template_no_namespace() { + assert_eq!( + ModuleTypeRegistration::extract_type_name("HashMap"), + "HashMap" + ); + } + + #[test] + fn parse_namespace_no_namespace() { + let (scope, name) = ModuleTypeRegistration::parse_namespace_and_name("MyType"); + assert!(scope.is_empty()); + assert_eq!(&*name, "MyType"); + } + #[test] + fn parse_namespace_single_level() { + let (scope, name) = ModuleTypeRegistration::parse_namespace_and_name("A.MyType"); + assert_eq!(scope, vec![RawIdentifier::new("A")]); + assert_eq!(&*name, "MyType"); + } + #[test] + fn parse_namespace_nested() { + let (scope, name) = ModuleTypeRegistration::parse_namespace_and_name("A.B.MyType"); + assert_eq!(scope, vec![RawIdentifier::new("A"), RawIdentifier::new("B")]); + assert_eq!(&*name, "MyType"); + } + #[test] + fn parse_namespace_deeply_nested() { + let (scope, name) = ModuleTypeRegistration::parse_namespace_and_name("SpacetimeDB.Internal.MyType"); + assert_eq!( + scope, + vec![RawIdentifier::new("SpacetimeDB"), RawIdentifier::new("Internal")] + ); + assert_eq!(&*name, "MyType"); + } + + // ============================================================ + // Describe type + // ============================================================ + + #[test] + fn describe_all_primitives() { + assert_eq!(ModuleTypeRegistration::describe_type(&AlgebraicType::Bool), "Bool"); + assert_eq!(ModuleTypeRegistration::describe_type(&AlgebraicType::U8), "U8"); + assert_eq!(ModuleTypeRegistration::describe_type(&AlgebraicType::U32), "U32"); + assert_eq!(ModuleTypeRegistration::describe_type(&AlgebraicType::I64), "I64"); + assert_eq!(ModuleTypeRegistration::describe_type(&AlgebraicType::F64), "F64"); + assert_eq!(ModuleTypeRegistration::describe_type(&AlgebraicType::String), "String"); + } + + #[test] + fn describe_array() { + let arr = AlgebraicType::Array(ArrayType { + elem_ty: Box::new(AlgebraicType::U32), + }); + assert_eq!(ModuleTypeRegistration::describe_type(&arr), "Array(U32)"); + } + #[test] + fn describe_array_nested() { + let arr = AlgebraicType::Array(ArrayType { + elem_ty: Box::new(AlgebraicType::Array(ArrayType { + elem_ty: Box::new(AlgebraicType::U8), + })), + }); + assert_eq!(ModuleTypeRegistration::describe_type(&arr), "Array(Array(U8))"); + } + + #[test] + fn describe_empty_product() { + assert_eq!( + ModuleTypeRegistration::describe_type(&AlgebraicType::Product(ProductType { + elements: vec![].into() + })), + "Product{}" + ); + } + #[test] + fn describe_product_unnamed() { + let ty = AlgebraicType::Product(ProductType { + elements: vec![ + ProductTypeElement { + name: None, + algebraic_type: AlgebraicType::U8, + }, + ProductTypeElement { + name: None, + algebraic_type: AlgebraicType::String, + }, + ] + .into(), + }); + let d = ModuleTypeRegistration::describe_type(&ty); + assert!(d.contains("U8")); + assert!(d.contains("String")); + assert!(d.starts_with("Product{")); + } + #[test] + fn describe_product_named() { + let ty = AlgebraicType::Product(ProductType { + elements: vec![ + ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U8, + }, + ProductTypeElement { + name: Some(RawIdentifier::new("y")), + algebraic_type: AlgebraicType::String, + }, + ] + .into(), + }); + let d = ModuleTypeRegistration::describe_type(&ty); + assert!(d.contains("x: U8")); + assert!(d.contains("y: String")); + } + + #[test] + fn describe_empty_sum() { + assert_eq!( + ModuleTypeRegistration::describe_type(&AlgebraicType::Sum(SumType { + variants: vec![].into() + })), + "Sum{}" + ); + } + #[test] + fn describe_option() { + assert_eq!(ModuleTypeRegistration::describe_type(&make_option()), "Option"); + } + #[test] + fn describe_sum_non_option() { + let sum = AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("A")), + algebraic_type: AlgebraicType::U8, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("B")), + algebraic_type: AlgebraicType::String, + }, + ] + .into(), + }); + let d = ModuleTypeRegistration::describe_type(&sum); + assert!(d.contains("A: U8")); + assert!(d.contains("B: String")); + assert!(d.contains(" | ")); + } + #[test] + fn describe_ref() { + assert_eq!( + ModuleTypeRegistration::describe_type(&AlgebraicType::Ref(AlgebraicTypeRef(42))), + "Ref(42)" + ); + } + + // ============================================================ + // Registration state + // ============================================================ + + #[test] + fn new_is_clean() { + let reg = ModuleTypeRegistration::new(); + assert!(!reg.has_error()); + assert!(reg.error().is_none()); + } + #[test] + fn clear_resets_state() { + let mut reg = ModuleTypeRegistration::new(); + reg.error = Some(RegistrationError { + message: "err".into(), + type_description: "desc".into(), + }); + reg.types_being_registered.insert("Foo".into()); + reg.type_name_cache.insert("Foo".into(), AlgebraicTypeRef(0)); + reg.typespace.types.push(AlgebraicType::U8); + reg.type_defs.push(RawTypeDefV10 { + source_name: RawScopedTypeNameV10 { + scope: vec![].into(), + source_name: RawIdentifier::new("X"), + }, + ty: AlgebraicTypeRef(0), + custom_ordering: false, + }); + reg.clear(); + assert!(!reg.has_error()); + assert!(reg.error().is_none()); + assert!(reg.types_being_registered.is_empty()); + assert!(reg.type_name_cache.is_empty()); + assert!(reg.typespace.types.is_empty()); + assert!(reg.type_defs.is_empty()); + } + #[test] + fn has_error_when_set() { + let mut reg = ModuleTypeRegistration::new(); + reg.error = Some(RegistrationError { + message: "fail".into(), + type_description: "desc".into(), + }); + assert!(reg.has_error()); + } + + #[test] + fn convert_unit_type() { + let unit = ModuleTypeRegistration::convert_unit_type(); + assert!(matches!(unit, AlgebraicType::Product(ref p) if p.elements.is_empty())); + } + + // ============================================================ + // register_type — primitives + // ============================================================ + + #[test] + fn register_primitive_bool() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!( + reg.register_type(AlgebraicType::Bool, ""), + AlgebraicType::Bool + )); + } + #[test] + fn register_primitive_u32() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!(reg.register_type(AlgebraicType::U32, ""), AlgebraicType::U32)); + } + #[test] + fn register_primitive_string() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!( + reg.register_type(AlgebraicType::String, ""), + AlgebraicType::String + )); + } + #[test] + fn register_primitive_ignores_name() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!( + reg.register_type(AlgebraicType::U64, "MyAlias"), + AlgebraicType::U64 + )); + } + + #[test] + fn register_ref_passthrough() { + let mut reg = ModuleTypeRegistration::new(); + assert!( + matches!(reg.register_type(AlgebraicType::Ref(AlgebraicTypeRef(5)), ""), AlgebraicType::Ref(r) if r.0 == 5) + ); + } + + #[test] + fn register_array_of_primitive() { + let mut reg = ModuleTypeRegistration::new(); + let arr = AlgebraicType::Array(ArrayType { + elem_ty: Box::new(AlgebraicType::U8), + }); + assert!(matches!(reg.register_type(arr, ""), AlgebraicType::Array(_))); + } + #[test] + fn register_array_preserves_element() { + let mut reg = ModuleTypeRegistration::new(); + let arr = AlgebraicType::Array(ArrayType { + elem_ty: Box::new(AlgebraicType::I64), + }); + let result = reg.register_type(arr, ""); + if let AlgebraicType::Array(a) = result { + assert!(matches!(*a.elem_ty, AlgebraicType::I64)); + } else { + panic!("expected Array"); + } + } + + // ============================================================ + // register_type — inlined composites + // ============================================================ + + #[test] + fn register_option_is_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let result = reg.register_type(make_option(), "MyOption"); + assert!(matches!(result, AlgebraicType::Sum(_))); + } + #[test] + fn register_result_is_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let result = reg.register_type(make_result(), "MyResult"); + assert!(matches!(result, AlgebraicType::Sum(_))); + } + #[test] + fn register_schedule_at_is_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let result = reg.register_type(make_schedule_at(), "MyScheduleAt"); + assert!(matches!(result, AlgebraicType::Sum(_))); + } + + #[test] + fn register_special_identity_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("__identity__")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + assert!(matches!(reg.register_type(ty, ""), AlgebraicType::Product(_))); + } + #[test] + fn register_special_connection_id_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("__connection_id__")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + assert!(matches!(reg.register_type(ty, ""), AlgebraicType::Product(_))); + } + #[test] + fn register_special_timestamp_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("__timestamp_micros_since_unix_epoch__")), + algebraic_type: AlgebraicType::I64, + }] + .into(), + }); + assert!(matches!(reg.register_type(ty, ""), AlgebraicType::Product(_))); + } + #[test] + fn register_special_duration_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("__time_duration_micros__")), + algebraic_type: AlgebraicType::I64, + }] + .into(), + }); + assert!(matches!(reg.register_type(ty, ""), AlgebraicType::Product(_))); + } + #[test] + fn register_special_uuid_inlined() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("__uuid__")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + assert!(matches!(reg.register_type(ty, ""), AlgebraicType::Product(_))); + } + + // ============================================================ + // register_type — error paths + // ============================================================ + + #[test] + fn register_missing_name_sets_error() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + let result = reg.register_type(ty, ""); + assert!(matches!(result, AlgebraicType::U8)); + assert!(reg.has_error()); + assert!(reg.error().unwrap().message.contains("Missing type name")); + } + + #[test] + fn register_circular_ref_sets_error() { + let mut reg = ModuleTypeRegistration::new(); + reg.types_being_registered.insert("Recursive".into()); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + let result = reg.register_type(ty, "Recursive"); + assert!(matches!(result, AlgebraicType::U8)); + assert!(reg.has_error()); + assert!(reg.error().unwrap().message.contains("Recursive type reference")); + } + + // ============================================================ + // convert_special_type + // ============================================================ + + #[test] + fn convert_special_converts_fields() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("__identity__")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + let result = reg.convert_special_type(&ty); + assert!(matches!(result, AlgebraicType::Product(ref p) if p.elements.len() == 1)); + } + #[test] + fn convert_special_non_product_fallback() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!( + reg.convert_special_type(&AlgebraicType::U8), + AlgebraicType::U8 + )); + } + + // ============================================================ + // convert_inline_sum + // ============================================================ + + #[test] + fn convert_inline_sum_variants() { + let mut reg = ModuleTypeRegistration::new(); + let sum = make_option(); + let result = reg.convert_inline_sum(&sum); + assert!(matches!(result, AlgebraicType::Sum(ref s) if s.variants.len() == 2)); + } + #[test] + fn convert_inline_sum_non_sum_fallback() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!( + reg.convert_inline_sum(&AlgebraicType::Bool), + AlgebraicType::U8 + )); + } + + // ============================================================ + // process_product / process_sum + // ============================================================ + + #[test] + fn process_product_preserves_fields() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ + ProductTypeElement { + name: Some(RawIdentifier::new("a")), + algebraic_type: AlgebraicType::U8, + }, + ProductTypeElement { + name: Some(RawIdentifier::new("b")), + algebraic_type: AlgebraicType::String, + }, + ] + .into(), + }); + let result = reg.process_product(&ty); + if let AlgebraicType::Product(ref p) = result { + assert_eq!(p.elements.len(), 2); + assert_eq!(p.elements[0].name.as_ref().map(|r| &**r), Some("a")); + assert_eq!(p.elements[1].name.as_ref().map(|r| &**r), Some("b")); + } else { + panic!("expected Product"); + } + } + #[test] + fn process_product_non_product_fallback() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!(reg.process_product(&AlgebraicType::U8), AlgebraicType::U8)); + } + + #[test] + fn process_sum_preserves_variants() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("A")), + algebraic_type: AlgebraicType::U8, + }, + SumTypeVariant { + name: Some(RawIdentifier::new("B")), + algebraic_type: AlgebraicType::I32, + }, + ] + .into(), + }); + let result = reg.process_sum(&ty); + if let AlgebraicType::Sum(ref s) = result { + assert_eq!(s.variants.len(), 2); + assert_eq!(s.variants[0].name.as_ref().map(|r| &**r), Some("A")); + assert_eq!(s.variants[1].name.as_ref().map(|r| &**r), Some("B")); + } else { + panic!("expected Sum"); + } + } + #[test] + fn process_sum_non_sum_fallback() { + let mut reg = ModuleTypeRegistration::new(); + assert!(matches!(reg.process_sum(&AlgebraicType::Bool), AlgebraicType::U8)); + } + + // ============================================================ + // Complex type registration (structs/enums) + // ============================================================ + + #[test] + fn register_struct_adds_to_typespace() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U32, + }] + .into(), + }); + let result = reg.register_type(ty, "MyStruct"); + assert!(matches!(result, AlgebraicType::Ref(r) if r.0 == 0)); + assert_eq!(reg.typespace.types.len(), 1); + assert_eq!(reg.type_defs.len(), 1); + assert_eq!(&*reg.type_defs[0].source_name.source_name, "MyStruct"); + } + + #[test] + fn register_enum_adds_to_typespace() { + let mut reg = ModuleTypeRegistration::new(); + let ty = AlgebraicType::Sum(SumType { + variants: vec![ + SumTypeVariant { + name: Some(RawIdentifier::new("A")), + algebraic_type: AlgebraicType::Product(ProductType { + elements: vec![].into(), + }), + }, + SumTypeVariant { + name: Some(RawIdentifier::new("B")), + algebraic_type: AlgebraicType::Product(ProductType { + elements: vec![].into(), + }), + }, + ] + .into(), + }); + let result = reg.register_type(ty, "MyEnum"); + assert!(matches!(result, AlgebraicType::Ref(r) if r.0 == 0)); + assert_eq!(reg.typespace.types.len(), 1); + assert_eq!(reg.type_defs.len(), 1); + } + + #[test] + fn register_nested_struct() { + let mut reg = ModuleTypeRegistration::new(); + let inner = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("y")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }); + let outer = AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("inner")), + algebraic_type: inner.clone(), + }] + .into(), + }); + + // First register Inner + reg.register_type(inner, "Inner"); + // Then register Outer + let result = reg.register_type(outer, "Outer"); + assert!(matches!(result, AlgebraicType::Ref(r) if r.0 == 1)); + assert_eq!(reg.typespace.types.len(), 2); + assert_eq!(reg.type_defs.len(), 2); + } + + // ============================================================ + // build / serialize + // ============================================================ + + #[test] + fn build_module_def_contains_sections() { + let mut reg = ModuleTypeRegistration::new(); + reg.register_type( + AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }), + "TestStruct", + ); + let module = reg.build_module_def(); + assert_eq!(module.sections.len(), 2); + assert!(matches!(&module.sections[0], RawModuleDefV10Section::Typespace(_))); + assert!(matches!(&module.sections[1], RawModuleDefV10Section::Types(_))); + } + + #[test] + fn serialize_module_def_produces_bytes() { + let mut reg = ModuleTypeRegistration::new(); + reg.register_type( + AlgebraicType::Product(ProductType { + elements: vec![ProductTypeElement { + name: Some(RawIdentifier::new("x")), + algebraic_type: AlgebraicType::U8, + }] + .into(), + }), + "TestStruct", + ); + let bytes = serialize_module_def(®); + assert!(!bytes.is_empty()); + } + + #[test] + fn serialize_empty_module() { + let reg = ModuleTypeRegistration::new(); + let bytes = serialize_module_def(®); + assert!(!bytes.is_empty()); + } +} diff --git a/crates/update/src/cli/uninstall.rs b/crates/update/src/cli/uninstall.rs index 0ec1f69caf1..3c9e14dbe31 100644 --- a/crates/update/src/cli/uninstall.rs +++ b/crates/update/src/cli/uninstall.rs @@ -28,10 +28,97 @@ impl Uninstall { Ok(None) => {} Err(e) => tracing::warn!("{e:#}"), } + let dir = paths.cli_bin_dir.version_dir(&version); + if !dir.0.exists() { + anyhow::bail!("v{version} is not installed"); + } if yes.confirm(format!("Uninstall v{version}?"))? { - let dir = paths.cli_bin_dir.version_dir(&version); - std::fs::remove_dir_all(dir)?; + std::fs::remove_dir_all(&dir)?; } Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use spacetimedb_paths::FromPathUnchecked; + use spacetimedb_paths::RootDir; + + fn make_temp_paths() -> (tempfile::TempDir, SpacetimePaths) { + let tmp = tempfile::tempdir().unwrap(); + let base = tmp.path().join("spacetime"); + std::fs::create_dir_all(&base).unwrap(); + let root = RootDir::from_path_unchecked(base); + let paths = SpacetimePaths::from_root_dir(&root); + (tmp, paths) + } + + #[test] + fn test_uninstall_nonexistent_version_errors_before_prompt() { + let (_tmp, paths) = make_temp_paths(); + let uninstall = Uninstall { + version: "9.9.9".to_owned(), + yes: ForceYes { yes: true }, + }; + let result = uninstall.exec(&paths); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string().contains("9.9.9"), + "error should mention the version number" + ); + assert!( + err.to_string().contains("not installed"), + "error should say 'not installed'" + ); + } + + #[test] + fn test_uninstall_current_version_errors() { + let (_tmp, paths) = make_temp_paths(); + // Create the "current" symlink target so it exists on disk + let current_dir = paths.cli_bin_dir.version_dir("2.0.0"); + std::fs::create_dir_all(¤t_dir.0).unwrap(); + paths.cli_bin_dir.set_current_version("2.0.0").unwrap(); + + let uninstall = Uninstall { + version: "2.0.0".to_owned(), + yes: ForceYes { yes: true }, + }; + let result = uninstall.exec(&paths); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("currently used version"),); + } + + #[test] + fn test_uninstall_current_keyword_errors() { + let (_tmp, paths) = make_temp_paths(); + let uninstall = Uninstall { + version: "current".to_owned(), + yes: ForceYes { yes: true }, + }; + let result = uninstall.exec(&paths); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("cannot remove `current`"),); + } + + #[test] + fn test_uninstall_existing_version_with_yes() { + let (_tmp, paths) = make_temp_paths(); + let version_dir = paths.cli_bin_dir.version_dir("1.0.0"); + std::fs::create_dir_all(&version_dir.0).unwrap(); + // Create a dummy file so we can verify the directory existed + std::fs::write(version_dir.0.join("spacetime"), "dummy").unwrap(); + + assert!(version_dir.0.exists(), "version dir should exist before"); + + let uninstall = Uninstall { + version: "1.0.0".to_owned(), + yes: ForceYes { yes: true }, + }; + uninstall.exec(&paths).unwrap(); + + assert!(!version_dir.0.exists(), "version dir should be removed after uninstall"); + } +}