diff --git a/crates/bindings-macro/src/lib.rs b/crates/bindings-macro/src/lib.rs index 52b13f00267..cb59682c4f3 100644 --- a/crates/bindings-macro/src/lib.rs +++ b/crates/bindings-macro/src/lib.rs @@ -52,6 +52,7 @@ mod sym { symbol!(index); symbol!(init); symbol!(name); + symbol!(on_abort); symbol!(primary_key); symbol!(private); symbol!(public); diff --git a/crates/bindings-macro/src/procedure.rs b/crates/bindings-macro/src/procedure.rs index 8b1f7246a74..beb410ea1ee 100644 --- a/crates/bindings-macro/src/procedure.rs +++ b/crates/bindings-macro/src/procedure.rs @@ -10,6 +10,8 @@ use syn::{ItemFn, LitStr}; pub(crate) struct ProcedureArgs { /// For consistency with reducers: allow specifying a different export name than the Rust function name. name: Option, + /// Optional procedure to invoke if this procedure aborts. + on_abort: Option, } impl ProcedureArgs { @@ -21,6 +23,10 @@ impl ProcedureArgs { check_duplicate(&args.name, &meta)?; args.name = Some(meta.value()?.parse()?); } + sym::on_abort => { + check_duplicate(&args.on_abort, &meta)?; + args.on_abort = Some(meta.value()?.parse()?); + } }); Ok(()) }) @@ -34,6 +40,10 @@ pub(crate) fn procedure_impl(args: ProcedureArgs, original_function: &ItemFn) -> let vis = &original_function.vis; let procedure_name = args.name.unwrap_or_else(|| ident_to_litstr(func_name)); + let on_abort = match args.on_abort { + Some(lit) => quote!(Some(#lit)), + None => quote!(None), + }; assert_only_lifetime_generics(original_function, "procedures")?; @@ -115,6 +125,9 @@ pub(crate) fn procedure_impl(args: ProcedureArgs, original_function: &ItemFn) -> /// The name of this function const NAME: &'static str = #procedure_name; + /// The name of the on-abort handler, if any. + const ON_ABORT: Option<&'static str> = #on_abort; + /// The parameter names of this function const ARG_NAMES: &'static [Option<&'static str>] = &[#(#opt_arg_names),*]; diff --git a/crates/bindings/src/rt.rs b/crates/bindings/src/rt.rs index 4c2ec69254d..7c8d0c8caaa 100644 --- a/crates/bindings/src/rt.rs +++ b/crates/bindings/src/rt.rs @@ -158,6 +158,9 @@ pub trait FnInfo { /// The name of the function. const NAME: &'static str; + /// The name of the on-abort handler, if any. + const ON_ABORT: Option<&'static str> = None; + /// The lifecycle of the function, if there is one. const LIFECYCLE: Option = None; @@ -799,7 +802,9 @@ where register_describer(|module| { let params = A::schema::(&mut module.inner); let ret_ty = ::make_type(&mut module.inner); - module.inner.add_procedure(I::NAME, params, ret_ty); + module + .inner + .add_procedure(I::NAME, params, ret_ty, I::ON_ABORT.map(Into::into)); module.procedures.push(I::INVOKE); }) } diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index d265191b9b8..d098abcf50e 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -12,10 +12,10 @@ use crate::host::module_host::{ ClientConnectedError, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall, ModuleInfo, RefInstance, ViewCallResult, ViewCommand, ViewCommandResult, ViewOutcome, }; -use crate::host::scheduler::{CallScheduledFunctionResult, ScheduledFunctionParams}; +use crate::host::scheduler::{CallScheduledFunctionResult, ScheduledFunctionParams, Scheduler}; use crate::host::{ - ArgsTuple, ModuleHost, ProcedureCallError, ProcedureCallResult, ReducerCallError, ReducerCallResult, ReducerId, - ReducerOutcome, Scheduler, UpdateDatabaseResult, + ArgsTuple, FunctionArgs, ModuleHost, ProcedureCallError, ProcedureCallResult, ReducerCallError, ReducerCallResult, + ReducerId, ReducerOutcome, UpdateDatabaseResult, }; use crate::identity::Identity; use crate::messages::control_db::HostType; @@ -493,6 +493,7 @@ impl WasmModuleInstance { pub struct InstanceCommon { info: Arc, energy_monitor: Arc, + scheduler: Scheduler, allocated_memory: usize, metric_wasm_memory_bytes: IntGauge, vm_metrics: AllVmMetrics, @@ -507,6 +508,7 @@ impl InstanceCommon { info: module.info(), vm_metrics, energy_monitor: module.energy_monitor(), + scheduler: module.scheduler().clone(), // Will be updated on the first reducer call. allocated_memory: 0, metric_wasm_memory_bytes: WORKER_METRICS @@ -519,6 +521,10 @@ impl InstanceCommon { self.info.clone() } + pub(crate) fn scheduler(&self) -> &Scheduler { + &self.scheduler + } + #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn update_database( &mut self, @@ -727,6 +733,15 @@ impl InstanceCommon { let trapped = call_result.is_err(); + if trapped { + if let Some(handler) = procedure_def.on_abort.as_ref() { + self.scheduler().volatile_nonatomic_schedule_immediate( + handler.to_string(), + FunctionArgs::Bsatn(args.get_bsatn().clone()), + ); + } + } + let result = match call_result { Err(err) => { inst.log_traceback("procedure", &procedure_def.name, &err); diff --git a/crates/lib/src/db/raw_def/v10.rs b/crates/lib/src/db/raw_def/v10.rs index 4f060f4ff2b..7ec9ddd1098 100644 --- a/crates/lib/src/db/raw_def/v10.rs +++ b/crates/lib/src/db/raw_def/v10.rs @@ -342,6 +342,9 @@ pub struct RawProcedureDefV10 { /// it should be registered in the typespace and indirected through an [`AlgebraicType::Ref`]. pub return_type: AlgebraicType, + /// The name of the procedure to invoke if this procedure aborts. + pub on_abort: Option, + /// Whether this procedure is callable from clients or is internal-only. pub visibility: FunctionVisibility, } @@ -927,11 +930,13 @@ impl RawModuleDefV10Builder { source_name: impl Into, params: ProductType, return_type: AlgebraicType, + on_abort: Option, ) { self.procedures_mut().push(RawProcedureDefV10 { source_name: source_name.into(), params, return_type, + on_abort, visibility: FunctionVisibility::ClientCallable, }) } diff --git a/crates/lib/src/db/raw_def/v9.rs b/crates/lib/src/db/raw_def/v9.rs index 20b32250d89..7b060cc125d 100644 --- a/crates/lib/src/db/raw_def/v9.rs +++ b/crates/lib/src/db/raw_def/v9.rs @@ -592,6 +592,9 @@ pub struct RawProcedureDefV9 { /// This `ProductType` need not be registered in the typespace. pub params: ProductType, + /// The name of the procedure to invoke if this procedure aborts. + pub on_abort: Option, + /// The type of the return value. /// /// If this is a user-defined product or sum type, @@ -784,8 +787,9 @@ impl RawModuleDefV9Builder { pub fn add_procedure( &mut self, name: impl Into, - params: spacetimedb_sats::ProductType, - return_type: spacetimedb_sats::AlgebraicType, + params: ProductType, + return_type: AlgebraicType, + on_abort: Option, ) { self.module .misc_exports @@ -793,6 +797,7 @@ impl RawModuleDefV9Builder { name: name.into(), params, return_type, + on_abort, })) } diff --git a/crates/schema/src/def.rs b/crates/schema/src/def.rs index 3469b03d828..f3e18251096 100644 --- a/crates/schema/src/def.rs +++ b/crates/schema/src/def.rs @@ -1645,6 +1645,9 @@ pub struct ProcedureDef { /// and indirected through an [`AlgebraicTypeUse::Ref`]. pub return_type_for_generate: AlgebraicTypeUse, + /// The name of the procedure to invoke if this procedure aborts. + pub on_abort: Option, + /// The visibility of this procedure. pub visibility: FunctionVisibility, } @@ -1655,6 +1658,7 @@ impl From for RawProcedureDefV9 { name: val.name.into(), params: val.params, return_type: val.return_type, + on_abort: val.on_abort.map(Into::into), } } } @@ -1665,6 +1669,7 @@ impl From for RawProcedureDefV10 { source_name: val.name.into(), params: val.params, return_type: val.return_type, + on_abort: val.on_abort.map(Into::into), visibility: val.visibility.into(), } } diff --git a/crates/schema/src/def/validate/v10.rs b/crates/schema/src/def/validate/v10.rs index 7c8fe940106..383af4f74d0 100644 --- a/crates/schema/src/def/validate/v10.rs +++ b/crates/schema/src/def/validate/v10.rs @@ -4,8 +4,8 @@ use spacetimedb_lib::de::DeserializeSeed as _; use spacetimedb_sats::{Typespace, WithTypespace}; use crate::def::validate::v9::{ - check_function_names_are_unique, check_scheduled_functions_exist, generate_schedule_name, identifier, - CoreValidator, TableValidator, ViewValidator, + check_function_names_are_unique, check_procedure_on_abort_handlers, check_scheduled_functions_exist, + generate_schedule_name, identifier, CoreValidator, TableValidator, ViewValidator, }; use crate::def::*; use crate::error::ValidationError; @@ -164,6 +164,7 @@ pub fn validate(def: RawModuleDefV10) -> Result { // Attach schedules to their respective tables attach_schedules_to_tables(&mut tables, schedules)?; + check_procedure_on_abort_handlers(&procedures)?; check_scheduled_functions_exist(&mut tables, &reducers, &procedures)?; change_scheduled_functions_and_lifetimes_visibility(&tables, &mut reducers, &mut procedures)?; @@ -519,6 +520,7 @@ impl<'a> ModuleValidatorV10<'a> { source_name, params, return_type, + on_abort, visibility, } = procedure_def; @@ -538,9 +540,10 @@ impl<'a> ModuleValidatorV10<'a> { ); let name_result = identifier(source_name); + let on_abort = on_abort.map(identifier).transpose(); - let (name_result, params_for_generate, return_type_for_generate) = - (name_result, params_for_generate, return_type_for_generate).combine_errors()?; + let (name_result, params_for_generate, return_type_for_generate, on_abort) = + (name_result, params_for_generate, return_type_for_generate, on_abort).combine_errors()?; Ok(ProcedureDef { name: name_result, @@ -551,6 +554,7 @@ impl<'a> ModuleValidatorV10<'a> { }, return_type, return_type_for_generate, + on_abort, visibility: visibility.into(), }) } @@ -1491,8 +1495,13 @@ mod tests { fn duplicate_procedure_names() { let mut builder = RawModuleDefV10Builder::new(); - builder.add_procedure("foo", [("i", AlgebraicType::I32)].into(), AlgebraicType::unit()); - builder.add_procedure("foo", [("name", AlgebraicType::String)].into(), AlgebraicType::unit()); + builder.add_procedure("foo", [("i", AlgebraicType::I32)].into(), AlgebraicType::unit(), None); + builder.add_procedure( + "foo", + [("name", AlgebraicType::String)].into(), + AlgebraicType::unit(), + None, + ); let result: Result = builder.finish().try_into(); @@ -1506,7 +1515,7 @@ mod tests { let mut builder = RawModuleDefV10Builder::new(); builder.add_reducer("foo", [("i", AlgebraicType::I32)].into()); - builder.add_procedure("foo", [("i", AlgebraicType::I32)].into(), AlgebraicType::unit()); + builder.add_procedure("foo", [("i", AlgebraicType::I32)].into(), AlgebraicType::unit(), None); let result: Result = builder.finish().try_into(); diff --git a/crates/schema/src/def/validate/v9.rs b/crates/schema/src/def/validate/v9.rs index 03c9e6f8338..c3b98c31f01 100644 --- a/crates/schema/src/def/validate/v9.rs +++ b/crates/schema/src/def/validate/v9.rs @@ -131,6 +131,7 @@ pub fn validate(def: RawModuleDefV9) -> Result { check_non_procedure_misc_exports(misc_exports, &validator, &mut tables), ) .combine_errors()?; + check_procedure_on_abort_handlers(&procedures)?; check_scheduled_functions_exist(&mut tables, &reducers, &procedures)?; Ok((tables, types, reducers, procedures, views)) }); @@ -376,6 +377,7 @@ impl ModuleValidatorV9<'_> { name, params, return_type, + on_abort, } = procedure_def; let params_for_generate = @@ -396,9 +398,11 @@ impl ModuleValidatorV9<'_> { // Procedures share the "function namespace" with reducers. // Uniqueness is validated in a later pass, in `check_function_names_are_unique`. let name = identifier(name); + let on_abort = on_abort.map(identifier).transpose(); let (name, params_for_generate, return_type_for_generate) = (name, params_for_generate, return_type_for_generate).combine_errors()?; + let on_abort = on_abort?; Ok(ProcedureDef { name, @@ -409,6 +413,7 @@ impl ModuleValidatorV9<'_> { }, return_type, return_type_for_generate, + on_abort, visibility: FunctionVisibility::ClientCallable, }) } @@ -1277,6 +1282,35 @@ pub(crate) fn identifier(name: RawIdentifier) -> Result { Identifier::new(name).map_err(|error| ValidationError::IdentifierError { error }.into()) } +/// Check that every procedure's on-abort handler exists and has matching params. +pub(crate) fn check_procedure_on_abort_handlers(procedures: &IndexMap) -> Result<()> { + procedures + .values() + .filter_map(|procedure| procedure.on_abort.as_ref().map(|handler| (procedure, handler))) + .map(|(procedure, handler)| { + let Some(handler_def) = procedures.get(handler) else { + return Err(ValidationError::MissingProcedureOnAbortHandler { + procedure: procedure.name.clone(), + handler: handler.clone(), + } + .into()); + }; + + if handler_def.params == procedure.params { + Ok(()) + } else { + Err(ValidationError::ProcedureOnAbortParamsMismatch { + procedure: procedure.name.clone(), + handler: handler.clone(), + expected: procedure.params.clone().into(), + actual: handler_def.params.clone().into(), + } + .into()) + } + }) + .collect_all_errors() +} + /// Check that every [`ScheduleDef`]'s `function_name` refers to a real reducer or procedure /// and that the function's arguments are appropriate for the table, /// then record the scheduled function's [`FunctionKind`] in the [`ScheduleDef`]. @@ -2206,8 +2240,18 @@ mod tests { fn duplicate_procedure_names() { let mut builder = RawModuleDefV9Builder::new(); - builder.add_procedure("foo", [("i", AlgebraicType::I32)].into(), AlgebraicType::unit()); - builder.add_procedure("foo", [("name", AlgebraicType::String)].into(), AlgebraicType::unit()); + builder.add_procedure( + "foo", + [("i", AlgebraicType::I32)].into(), + AlgebraicType::unit(), + /* on_abort */ None, + ); + builder.add_procedure( + "foo", + [("name", AlgebraicType::String)].into(), + AlgebraicType::unit(), + /* on_abort */ None, + ); let result: Result = builder.finish().try_into(); @@ -2221,7 +2265,12 @@ mod tests { let mut builder = RawModuleDefV9Builder::new(); builder.add_reducer("foo", [("i", AlgebraicType::I32)].into(), None); - builder.add_procedure("foo", [("i", AlgebraicType::I32)].into(), AlgebraicType::unit()); + builder.add_procedure( + "foo", + [("i", AlgebraicType::I32)].into(), + AlgebraicType::unit(), + /* on_abort */ None, + ); let result: Result = builder.finish().try_into(); diff --git a/crates/schema/src/error.rs b/crates/schema/src/error.rs index 06f284998b5..c81a7c5fb92 100644 --- a/crates/schema/src/error.rs +++ b/crates/schema/src/error.rs @@ -119,6 +119,15 @@ pub enum ValidationError { expected: PrettyAlgebraicType, actual: PrettyAlgebraicType, }, + #[error("Procedure {procedure} specifies on_abort handler {handler} that does not exist")] + MissingProcedureOnAbortHandler { procedure: Identifier, handler: Identifier }, + #[error("Procedure {procedure} on_abort handler {handler} expected params {expected}, but has params {actual}")] + ProcedureOnAbortParamsMismatch { + procedure: Identifier, + handler: Identifier, + expected: PrettyAlgebraicType, + actual: PrettyAlgebraicType, + }, #[error("Table name is reserved for system use: {table}")] TableNameReserved { table: Identifier }, #[error("Row-level security invalid: `{error}`, query: `{sql}")]