diff --git a/plugins/lsps-plugin/Cargo.toml b/plugins/lsps-plugin/Cargo.toml index 67f9297ca825..55dbc82e934a 100644 --- a/plugins/lsps-plugin/Cargo.toml +++ b/plugins/lsps-plugin/Cargo.toml @@ -6,18 +6,24 @@ edition = "2021" [[bin]] name = "cln-lsps-client" path = "src/client.rs" +required-features = ["cln"] [[bin]] name = "cln-lsps-service" path = "src/service.rs" +required-features = ["cln"] + +[features] +default = ["cln"] +cln = ["cln-plugin", "cln-rpc"] [dependencies] anyhow = "1.0" async-trait = "0.1" -bitcoin = "0.32.2" +bitcoin = { version = "0.32", features = ["serde"] } chrono = { version= "0.4.42", features = ["serde"] } -cln-plugin = { workspace = true } -cln-rpc = { workspace = true } +cln-plugin = { workspace = true, optional = true } +cln-rpc = { workspace = true, optional = true } hex = "0.4" log = "0.4" paste = "1.0.15" @@ -25,4 +31,4 @@ rand = "0.9" serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0", features = ["raw_value"] } thiserror = "2.0" -tokio = { version = "1.44", features = ["full"] } +tokio = { version = "1.44", features = ["full", "test-util"] } diff --git a/plugins/lsps-plugin/src/client.rs b/plugins/lsps-plugin/src/client.rs index bdf3475fb0f2..bb15cd7ece9b 100644 --- a/plugins/lsps-plugin/src/client.rs +++ b/plugins/lsps-plugin/src/client.rs @@ -17,11 +17,11 @@ use cln_lsps::{ transport::{MultiplexedTransport, PendingRequests}, }, proto::{ - lsps0::{Msat, LSPS0_MESSAGE_TYPE, LSP_FEATURE_BIT}, + lsps0::{Msat, LSP_FEATURE_BIT}, lsps2::{compute_opening_fee, Lsps2BuyResponse, Lsps2GetInfoResponse, OpeningFeeParams}, }, }; -use cln_plugin::{options, HookBuilder, HookFilter}; +use cln_plugin::options; use cln_rpc::{ model::{ requests::{ @@ -82,10 +82,7 @@ impl ClientState for State { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { if let Some(plugin) = cln_plugin::Builder::new(tokio::io::stdin(), tokio::io::stdout()) - .hook_from_builder( - HookBuilder::new("custommsg", 
hooks::client_custommsg_hook) - .filters(vec![HookFilter::Int(i64::from(LSPS0_MESSAGE_TYPE))]), - ) + .hook("custommsg", hooks::client_custommsg_hook) .option(OPTION_ENABLED) .rpcmethod( "lsps-listprotocols", @@ -485,7 +482,7 @@ async fn on_lsps_lsps2_invoice( // 5. Approve jit_channel_scid for a jit channel opening. let appr_req = ClnRpcLsps2Approve { lsp_id: req.lsp_id, - jit_channel_scid: buy_res.jit_channel_scid, + jit_channel_scid: buy_res.jit_channel_scid.into(), payment_hash: public_inv.payment_hash.to_string(), client_trusts_lsp: Some(buy_res.client_trusts_lsp), }; @@ -698,6 +695,7 @@ async fn on_openchannel( return Ok(serde_json::json!({ "result": "continue", "mindepth": 0, + "reserve": 0, })); } else { // Not a requested JIT-channel opening, continue. diff --git a/plugins/lsps-plugin/src/cln_adapters/mod.rs b/plugins/lsps-plugin/src/cln_adapters/mod.rs index 063690099ca2..cb162acdcf80 100644 --- a/plugins/lsps-plugin/src/cln_adapters/mod.rs +++ b/plugins/lsps-plugin/src/cln_adapters/mod.rs @@ -3,3 +3,8 @@ pub mod rpc; pub mod sender; pub mod state; pub mod types; + +pub use rpc::{ + ClnActionExecutor, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, + ClnRpcClient, +}; diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 3471d5838d24..aa841779f15c 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -1,13 +1,17 @@ use crate::{ - core::lsps2::provider::{ - Blockheight, BlockheightProvider, DatastoreProvider, LightningProvider, Lsps2OfferProvider, + core::lsps2::{ + actor::ActionExecutor, + provider::{ + ChannelRecoveryInfo, DatastoreProvider, + ForwardActivity, Lsps2PolicyProvider, RecoveryProvider, + }, }, proto::{ - lsps0::Msat, + lsps0::{Msat, ShortChannelId}, lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, 
+ DatastoreEntry, FinalizedDatastoreEntry, Lsps2PolicyBuyRequest, Lsps2PolicyBuyResponse, + Lsps2PolicyGetInfoRequest, Lsps2PolicyGetInfoResponse, OpeningFeeParams, + SessionOutcome, }, }, }; @@ -17,112 +21,499 @@ use bitcoin::secp256k1::PublicKey; use cln_rpc::{ model::{ requests::{ - DatastoreMode, DatastoreRequest, DeldatastoreRequest, FundchannelRequest, - GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, + AddpsbtoutputRequest, CloseRequest, ConnectRequest, DatastoreMode, DatastoreRequest, + DeldatastoreRequest, DisconnectRequest, FundchannelCancelRequest, + FundchannelCompleteRequest, FundchannelStartRequest, FundpsbtRequest, GetinfoRequest, + ListdatastoreRequest, ListforwardsIndex, ListforwardsRequest, ListpeerchannelsRequest, + SendpsbtRequest, SignpsbtRequest, UnreserveinputsRequest, }, - responses::ListdatastoreResponse, + responses::{ListdatastoreResponse, ListforwardsForwardsStatus}, }, - primitives::{Amount, AmountOrAll, ChannelState, Sha256, ShortChannelId}, + primitives::{Amount, AmountOrAll, ChannelState, Feerate, Sha256}, ClnRpc, }; use core::fmt; +use log::warn; use serde::Serialize; use std::path::PathBuf; +use std::str::FromStr; +use std::time::Duration; pub const DS_MAIN_KEY: &'static str = "lsps"; pub const DS_SUB_KEY: &'static str = "lsps2"; +pub const DS_SESSIONS_KEY: &str = "sessions"; +pub const DS_ACTIVE_KEY: &str = "active"; +pub const DS_FINALIZED_KEY: &str = "finalized"; + +// --------------------------------------------------------------------------- +// ClnRpcClient — shared connection helper +// --------------------------------------------------------------------------- #[derive(Clone)] -pub struct ClnApiRpc { +pub struct ClnRpcClient { rpc_path: PathBuf, } -impl ClnApiRpc { +impl ClnRpcClient { pub fn new(rpc_path: PathBuf) -> Self { Self { rpc_path } } - async fn create_rpc(&self) -> Result { + pub async fn create_rpc(&self) -> Result { + // Note: Add retry and backoff, be nicer than just failing. 
ClnRpc::new(&self.rpc_path).await } + + pub async fn poll_channel_ready( + &self, + channel_id: &Sha256, + timeout: Duration, + interval: Duration, + ) -> Result<()> { + let deadline = tokio::time::Instant::now() + timeout; + loop { + if self.check_channel_normal(channel_id).await? { + return Ok(()); + } + if tokio::time::Instant::now() + interval > deadline { + anyhow::bail!( + "timed out waiting for channel {} to reach CHANNELD_NORMAL", + channel_id + ); + } + tokio::time::sleep(interval).await; + } + } + + pub async fn check_channel_normal(&self, channel_id: &Sha256) -> Result { + let mut rpc = self.create_rpc().await?; + let r = rpc + .call_typed(&ListpeerchannelsRequest { + channel_id: Some(*channel_id), + id: None, + short_channel_id: None, + }) + .await + .with_context(|| "calling listpeerchannels")?; + + Ok(r.channels + .first() + .is_some_and(|ch| ch.state == ChannelState::CHANNELD_NORMAL)) + } + + pub async fn connect_with_retry(&self, peer_id: &str, timeout: Duration) -> Result<()> { + let deadline = tokio::time::Instant::now() + timeout; + let mut backoff = Duration::from_secs(1); + let max_backoff = Duration::from_secs(10); + + loop { + let mut rpc = self.create_rpc().await?; + let res = rpc + .call_typed(&ConnectRequest { + host: None, + port: None, + id: peer_id.to_string(), + }) + .await; + + if res.is_ok() { + return Ok(()); + } + + if tokio::time::Instant::now() + backoff > deadline { + anyhow::bail!("connect to {peer_id} timed out after {timeout:?}"); + } + + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(max_backoff); + } + } + + /// Get the short_channel_id for a channel, needed for listforwards queries. + /// Falls back to alias.local for unconfirmed JIT channels. 
+ pub async fn get_channel_scid( + &self, + channel_id: &str, + ) -> Result> { + let mut rpc = self.create_rpc().await?; + let peers = rpc + .call_typed(&ListpeerchannelsRequest { + channel_id: None, + id: None, + short_channel_id: None, + }) + .await?; + + for ch in &peers.channels { + if let Some(ref cid) = ch.channel_id { + if cid.to_string() == channel_id { + return Ok(ch + .short_channel_id + .or(ch.alias.as_ref().and_then(|a| a.local))); + } + } + } + Ok(None) + } + + pub async fn unreserve_inputs(&self, psbt: &str) -> Result<()> { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&UnreserveinputsRequest { + reserve: None, + psbt: psbt.to_string(), + }) + .await + .with_context(|| "calling unreserveinputs")?; + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// ClnActionExecutor — implements ActionExecutor +// --------------------------------------------------------------------------- + +/// Converts msat to sat, rounding up to avoid underfunding. 
+fn msat_to_sat_ceil(msat: u64) -> u64 { + msat.div_ceil(1000) +} + +#[derive(Clone)] +pub struct ClnActionExecutor { + rpc: ClnRpcClient, +} + +impl ClnActionExecutor { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } + } + + async fn cleanup_failed_funding(&self, peer_id: &PublicKey, psbt: &str) { + if let Err(e) = self.rpc.unreserve_inputs(psbt).await { + warn!("cleanup: unreserveinputs for psbt={psbt} failed: {e}"); + } + if let Err(e) = self.cancel_fundchannel(peer_id).await { + warn!("cleanup: fundchannel_cancel failed: {e}"); + } + } + + async fn cancel_fundchannel(&self, peer_id: &PublicKey) -> Result<()> { + let mut rpc = self.rpc.create_rpc().await?; + rpc.call_typed(&FundchannelCancelRequest { + id: peer_id.to_owned(), + }) + .await + .with_context(|| "calling fundchannel_cancel")?; + Ok(()) + } } #[async_trait] -impl LightningProvider for ClnApiRpc { - async fn fund_jit_channel( +impl ActionExecutor for ClnActionExecutor { + async fn fund_channel( &self, - peer_id: &PublicKey, - amount: &Msat, - ) -> Result<(Sha256, String)> { - let mut rpc = self.create_rpc().await?; - let res = rpc - .call_typed(&FundchannelRequest { + peer_id: String, + channel_size: Msat, + _opening_fee_params: OpeningFeeParams, + _scid: ShortChannelId, + ) -> anyhow::Result<(String, String)> { + let pk = PublicKey::from_str(&peer_id) + .with_context(|| format!("parsing peer_id '{peer_id}'"))?; + let channel_sat = msat_to_sat_ceil(channel_size.msat()); + + self.rpc.connect_with_retry(&peer_id, Duration::from_secs(90)) + .await?; + + let mut rpc = self.rpc.create_rpc().await?; + let start_res = rpc + .call_typed(&FundchannelStartRequest { + id: pk, + amount: Amount::from_sat(channel_sat), + mindepth: Some(0), + channel_type: Some(vec![12, 46, 50]), // zero_conf channel announce: Some(false), close_to: None, - compact_lease: None, feerate: None, - minconf: None, - mindepth: Some(0), push_msat: None, - request_amt: None, + reserve: Some(Amount::from_sat(0)), + }) + .await + 
.with_context(|| "calling fundchannel_start")?; + let funding_address = start_res.funding_address; + + // Reserve input and add to tx + let mut rpc = self.rpc.create_rpc().await?; + let fundpsbt_res = match rpc + .call_typed(&FundpsbtRequest { + satoshi: AmountOrAll::Amount(Amount::from_sat(channel_sat)), + feerate: Feerate::Normal, + startweight: 1000, + excess_as_change: Some(true), + locktime: None, + min_witness_weight: None, + minconf: None, + nonwrapped: None, + opening_anchor_channel: None, reserve: None, - channel_type: Some(vec![12, 46, 50]), - utxos: None, - amount: AmountOrAll::Amount(Amount::from_msat(amount.msat())), - id: peer_id.to_owned(), }) .await - .with_context(|| "calling fundchannel")?; - Ok((res.channel_id, res.txid)) + { + Ok(r) => r, + Err(e) => { + self.cancel_fundchannel(&pk).await.ok(); + return Err(anyhow::Error::new(e).context("calling fundpsbt")); + } + }; + + let addout_res = match rpc + .call_typed(&AddpsbtoutputRequest { + satoshi: Amount::from_sat(channel_sat), + initialpsbt: Some(fundpsbt_res.psbt.clone()), + destination: Some(funding_address), + locktime: None, + }) + .await + { + Ok(r) => r, + Err(e) => { + self.cleanup_failed_funding(&pk, &fundpsbt_res.psbt).await; + return Err(anyhow::Error::new(e).context("calling addpsbtoutput")); + } + }; + let psbt = addout_res.psbt; + + let complete_res = match rpc + .call_typed(&FundchannelCompleteRequest { + id: pk, + psbt: psbt.clone(), + withhold: Some(true), + }) + .await + { + Ok(r) => r, + Err(e) => { + self.cleanup_failed_funding(&pk, &psbt).await; + return Err(anyhow::Error::new(e).context("calling fundchannel_complete")); + } + }; + let channel_id = complete_res.channel_id; + + if let Err(e) = self + .rpc + .poll_channel_ready( + &channel_id, + Duration::from_secs(120), + Duration::from_secs(1), + ) + .await + { + self.cleanup_failed_funding(&pk, &psbt).await; + return Err(e); + } + + Ok((channel_id.to_string(), psbt)) } - async fn is_channel_ready(&self, peer_id: &PublicKey, 
channel_id: &Sha256) -> Result { - let mut rpc = self.create_rpc().await?; - let r = rpc + async fn broadcast_tx( + &self, + channel_id: String, + funding_psbt: String, + ) -> anyhow::Result { + // Idempotency: check if funding tx was already broadcast. + let sha = channel_id + .parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + let mut rpc = self.rpc.create_rpc().await?; + let list_res = rpc .call_typed(&ListpeerchannelsRequest { - channel_id: None, - id: Some(peer_id.to_owned()), + channel_id: Some(sha), + id: None, short_channel_id: None, }) .await - .with_context(|| "calling listpeerchannels")?; - - let chs = r - .channels - .iter() - .find(|&ch| ch.channel_id.is_some_and(|id| id == *channel_id)); - if let Some(ch) = chs { - if ch.state == ChannelState::CHANNELD_NORMAL { - return Ok(true); + .with_context(|| "calling listpeerchannels in broadcast_tx")?; + if let Some(ch) = list_res.channels.first() { + let already_broadcast = ch + .funding + .as_ref() + .and_then(|f| f.withheld) + .map(|w| !w) + .unwrap_or(false); + if already_broadcast { + // Tx was already broadcast; return the existing txid as a no-op. + if let Some(txid) = &ch.funding_txid { + return Ok(txid.clone()); + } } } - return Ok(false); + let mut rpc = self.rpc.create_rpc().await?; + let sign_res = rpc + .call_typed(&SignpsbtRequest { + psbt: funding_psbt, + signonly: None, + }) + .await + .with_context(|| "calling signpsbt")?; + let send_res = rpc + .call_typed(&SendpsbtRequest { + psbt: sign_res.signed_psbt, + reserve: None, + }) + .await + .with_context(|| "calling sendpsbt")?; + Ok(send_res.txid) + } + + async fn abandon_session( + &self, + channel_id: String, + funding_psbt: String, + ) -> anyhow::Result<()> { + // Idempotency: check if channel still exists. + if !self.is_channel_alive(&channel_id).await.unwrap_or(false) { + // Channel already gone — no-op. 
+ // TODO: Belt-and-suspenders: scan listpeerchannels for + // orphaned withheld channels not claimed by any session. + return Ok(()); + } + + let close_res = { + let mut rpc = self.rpc.create_rpc().await?; + rpc.call_typed(&CloseRequest { + destination: None, + fee_negotiation_step: None, + force_lease_closed: None, + unilateraltimeout: Some(1), // We didn't even broadcast the channel yet. + wrong_funding: None, + feerange: None, + id: channel_id.clone(), + }) + .await + .with_context(|| format!("calling close for channel_id={channel_id}")) + }; + + if let Err(e) = &close_res { + warn!("abandon_session: close failed for channel_id={channel_id}: {e}"); + } + + let unreserve_res = self.rpc.unreserve_inputs(&funding_psbt).await; + if let Err(e) = &unreserve_res { + warn!("abandon_session: unreserveinputs failed for funding_psbt={funding_psbt}: {e}"); + } + + match (close_res, unreserve_res) { + (Ok(_), Ok(())) => Ok(()), + (Err(close_err), Ok(())) => Err(close_err), + (Ok(_), Err(unreserve_err)) => Err(unreserve_err), + (Err(close_err), Err(unreserve_err)) => Err(anyhow::anyhow!( + "abandon_session failed for channel_id={channel_id}: close failed: {close_err}; unreserveinputs failed for funding_psbt={funding_psbt}: {unreserve_err}" + )), + } + } + + async fn disconnect(&self, peer_id: String) -> anyhow::Result<()> { + let pk = PublicKey::from_str(&peer_id) + .with_context(|| format!("parsing peer_id '{peer_id}'"))?; + let mut rpc = self.rpc.create_rpc().await?; + let _ = rpc + .call_typed(&DisconnectRequest { + id: pk, + force: None, + }) + .await + .with_context(|| "calling disconnect")?; + Ok(()) + } + + async fn is_channel_alive(&self, channel_id: &str) -> anyhow::Result { + let sha = channel_id + .parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + self.rpc.check_channel_normal(&sha).await + } +} + +// --------------------------------------------------------------------------- +// ClnDatastore — implements DatastoreProvider +// 
--------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct ClnDatastore { + rpc: ClnRpcClient, +} + +impl ClnDatastore { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } + } + + async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { + let mut rpc = self.rpc.create_rpc().await?; + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ]; + + let _ = rpc + .call_typed(&DeldatastoreRequest { + generation: None, + key, + }) + .await; + + Ok(()) } } #[async_trait] -impl DatastoreProvider for ClnApiRpc { +impl DatastoreProvider for ClnDatastore { async fn store_buy_request( &self, scid: &ShortChannelId, peer_id: &PublicKey, opening_fee_params: &OpeningFeeParams, expected_payment_size: &Option, - ) -> Result { - let mut rpc = self.create_rpc().await?; + channel_capacity_msat: &Msat, + ) -> Result { + let created_at = chrono::Utc::now(); + let mut rpc = self.rpc.create_rpc().await?; #[derive(Serialize)] struct BorrowedDatastoreEntry<'a> { peer_id: &'a PublicKey, opening_fee_params: &'a OpeningFeeParams, #[serde(borrow)] expected_payment_size: &'a Option, + channel_capacity_msat: &'a Msat, + created_at: chrono::DateTime, + #[serde(skip_serializing_if = "Option::is_none")] + channel_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + funding_psbt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + funding_txid: Option, + #[serde(skip_serializing_if = "Option::is_none")] + preimage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + forwards_updated_index: &'a Option, + #[serde(skip_serializing_if = "Option::is_none")] + payment_hash: Option, } let ds = BorrowedDatastoreEntry { peer_id, opening_fee_params, expected_payment_size, + channel_capacity_msat, + created_at, + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, + forwards_updated_index: 
&None, + payment_hash: None, }; let json_str = serde_json::to_string(&ds)?; @@ -134,6 +525,8 @@ impl DatastoreProvider for ClnApiRpc { key: vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), scid.to_string(), ], }; @@ -144,14 +537,28 @@ impl DatastoreProvider for ClnApiRpc { .map_err(anyhow::Error::new) .with_context(|| "calling datastore")?; - Ok(true) + Ok(DatastoreEntry { + peer_id: *peer_id, + opening_fee_params: opening_fee_params.clone(), + expected_payment_size: *expected_payment_size, + channel_capacity_msat: *channel_capacity_msat, + created_at, + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, + forwards_updated_index: None, + payment_hash: None, + }) } async fn get_buy_request(&self, scid: &ShortChannelId) -> Result { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let key = vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), scid.to_string(), ]; let res = rpc @@ -165,62 +572,284 @@ impl DatastoreProvider for ClnApiRpc { Ok(rec) } - async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { - let mut rpc = self.create_rpc().await?; + async fn save_session(&self, scid: &ShortChannelId, entry: &DatastoreEntry) -> Result<()> { + let json_str = serde_json::to_string(entry)?; + let mut rpc = self.rpc.create_rpc().await?; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::CREATE_OR_REPLACE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ], + }) + .await + .with_context(|| "calling datastore for save_session")?; + Ok(()) + } + + async fn finalize_session(&self, scid: &ShortChannelId, outcome: SessionOutcome) -> Result<()> { + let entry = match self.get_buy_request(scid).await { + 
Ok(e) => e, + Err(e) => { + warn!("finalize_session: active entry for scid={scid} already gone: {e}"); + return Ok(()); + } + }; + + let finalized = FinalizedDatastoreEntry { + entry, + outcome, + finalized_at: chrono::Utc::now(), + }; + let json_str = serde_json::to_string(&finalized)?; + + let mut rpc = self.rpc.create_rpc().await?; let key = vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_FINALIZED_KEY.to_string(), scid.to_string(), ]; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::MUST_CREATE), + string: Some(json_str), + key, + }) + .await + .with_context(|| "calling datastore for finalize_session")?; - let _ = rpc - .call_typed(&DeldatastoreRequest { - generation: None, - key, - }) - .await; - + self.del_buy_request(scid).await?; Ok(()) } + + async fn list_active_sessions(&self) -> Result> { + let mut rpc = self.rpc.create_rpc().await?; + let prefix = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + ]; + let res = rpc + .call_typed(&ListdatastoreRequest { key: Some(prefix) }) + .await + .with_context(|| "calling listdatastore for list_active_sessions")?; + + let mut sessions = Vec::new(); + for ds in &res.datastore { + if let Some(scid_str) = ds.key.last() { + if let Ok(scid) = scid_str.parse::() { + let json_str = ds.string.as_deref().unwrap_or(""); + if let Ok(entry) = serde_json::from_str::(json_str) { + sessions.push((scid, entry)); + } + } + } + } + Ok(sessions) + } +} + +// --------------------------------------------------------------------------- +// ClnPolicyProvider — implements Lsps2PolicyProvider +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct ClnPolicyProvider { + rpc: ClnRpcClient, +} + +impl ClnPolicyProvider { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } + } } #[async_trait] -impl Lsps2OfferProvider for 
ClnApiRpc { - async fn get_offer( +impl Lsps2PolicyProvider for ClnPolicyProvider { + async fn get_blockheight(&self) -> Result { + let mut rpc = self.rpc.create_rpc().await?; + let info = rpc + .call_typed(&GetinfoRequest {}) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling getinfo")?; + Ok(info.blockheight) + } + + async fn get_info( &self, request: &Lsps2PolicyGetInfoRequest, ) -> Result { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; rpc.call_raw("lsps2-policy-getpolicy", request) .await .context("failed to call lsps2-policy-getpolicy") } - async fn get_channel_capacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> Result { - let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getchannelcapacity", params) + async fn buy(&self, request: &Lsps2PolicyBuyRequest) -> Result { + let mut rpc = self.rpc.create_rpc().await?; + rpc.call_raw("lsps2-policy-buy", request) .await .map_err(anyhow::Error::new) - .with_context(|| "calling lsps2-policy-getchannelcapacity") + .with_context(|| "calling lsps2-policy-buy") + } +} + +// --------------------------------------------------------------------------- +// ClnRecoveryProvider — implements RecoveryProvider +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct ClnRecoveryProvider { + rpc: ClnRpcClient, +} + +impl ClnRecoveryProvider { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } } } #[async_trait] -impl BlockheightProvider for ClnApiRpc { - async fn get_blockheight(&self) -> Result { - let mut rpc = self.create_rpc().await?; - let info = rpc - .call_typed(&GetinfoRequest {}) +impl RecoveryProvider for ClnRecoveryProvider { + async fn get_forward_activity(&self, channel_id: &str) -> Result { + // Check historical forwards via listforwards using out_channel filter. + let scid = match self.rpc.get_channel_scid(channel_id).await? 
{ + Some(s) => s, + None => { + // Channel has no scid yet — no forwards possible. + return Ok(ForwardActivity::NoForwards); + } + }; + + let mut rpc = self.rpc.create_rpc().await?; + let fwd_res = rpc + .call_typed(&ListforwardsRequest { + in_channel: None, + index: Some(ListforwardsIndex::UPDATED), + limit: None, + out_channel: Some(scid), + start: None, + status: None, + }) .await - .map_err(anyhow::Error::new) - .with_context(|| "calling getinfo")?; - Ok(info.blockheight) + .with_context(|| "calling listforwards in get_forward_activity")?; + + if fwd_res.forwards.is_empty() { + return Ok(ForwardActivity::NoForwards); + } + + let mut has_offered = false; + for fwd in &fwd_res.forwards { + match fwd.status { + ListforwardsForwardsStatus::SETTLED => { + return Ok(ForwardActivity::Settled); + } + ListforwardsForwardsStatus::OFFERED => { + has_offered = true; + } + ListforwardsForwardsStatus::FAILED | ListforwardsForwardsStatus::LOCAL_FAILED => {} + } + } + + if has_offered { + return Ok(ForwardActivity::Offered); + } + + // All forwards failed. 
+ Ok(ForwardActivity::AllFailed) + } + + async fn get_channel_recovery_info(&self, channel_id: &str) -> Result { + let sha = channel_id + .parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + let mut rpc = self.rpc.create_rpc().await?; + let list_res = rpc + .call_typed(&ListpeerchannelsRequest { + channel_id: Some(sha), + id: None, + short_channel_id: None, + }) + .await + .with_context(|| "calling listpeerchannels in get_channel_recovery_info")?; + + match list_res.channels.first() { + None => Ok(ChannelRecoveryInfo { + exists: false, + withheld: false, + }), + Some(ch) => { + let withheld = ch + .funding + .as_ref() + .and_then(|f| f.withheld) + .unwrap_or(false); + Ok(ChannelRecoveryInfo { + exists: true, + withheld, + }) + } + } } + + async fn close_and_unreserve(&self, channel_id: &str, funding_psbt: &str) -> Result<()> { + let sha = channel_id.parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + if !self.rpc.check_channel_normal(&sha).await.unwrap_or(false) { + return Ok(()); + } + + let close_res = { + let mut rpc = self.rpc.create_rpc().await?; + rpc.call_typed(&CloseRequest { + destination: None, + fee_negotiation_step: None, + force_lease_closed: None, + unilateraltimeout: Some(1), + wrong_funding: None, + feerange: None, + id: channel_id.to_string(), + }) + .await + .with_context(|| format!("calling close for channel_id={channel_id}")) + }; + + if let Err(e) = &close_res { + warn!("close_and_unreserve: close failed for channel_id={channel_id}: {e}"); + } + + let unreserve_res = self.rpc.unreserve_inputs(funding_psbt).await; + if let Err(e) = &unreserve_res { + warn!("close_and_unreserve: unreserveinputs failed: {e}"); + } + + match (close_res, unreserve_res) { + (Ok(_), Ok(())) => Ok(()), + (Err(e), Ok(())) => Err(e), + (Ok(_), Err(e)) => Err(e), + (Err(ce), Err(ue)) => Err(anyhow::anyhow!( + "close_and_unreserve failed: close: {ce}; unreserve: {ue}" + )), + } + } + } +// 
--------------------------------------------------------------------------- +// Datastore helpers (standalone) +// --------------------------------------------------------------------------- + #[derive(Debug)] pub enum DsError { /// No datastore entry with this exact key. diff --git a/plugins/lsps-plugin/src/core/lsps2/actor.rs b/plugins/lsps-plugin/src/core/lsps2/actor.rs new file mode 100644 index 000000000000..b45a3cc7ea98 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/actor.rs @@ -0,0 +1,611 @@ +use crate::{ + core::lsps2::{ + event_sink::{EventSink, SessionEventEnvelope}, + provider::DatastoreProvider, + session::{PaymentPart, Session, SessionAction, SessionEvent, SessionInput}, + }, + proto::{ + lsps0::{Msat, ShortChannelId}, + lsps2::{DatastoreEntry, OpeningFeeParams}, + }, +}; +use anyhow::Result; +use async_trait::async_trait; +use bitcoin::hashes::sha256::Hash as PaymentHash; +use bitcoin::hashes::Hash; +use log::{debug, warn}; +use std::{collections::HashMap, sync::Arc, time::Duration}; +use tokio::sync::{mpsc, oneshot}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HtlcResponse { + Forward { + channel_id: String, + fee_msat: u64, + forward_msat: u64, + }, + Fail { + failure_code: &'static str, + }, + Continue, +} + +enum ActorInput { + AddPart { + part: PaymentPart, + reply_tx: oneshot::Sender, + }, + ChannelReady { + channel_id: String, + funding_psbt: String, + }, + FundingFailed, + PaymentSettled { + preimage: Option, + updated_index: Option, + }, + PaymentFailed { + updated_index: Option, + }, + FundingBroadcasted { txid: String }, + NewBlock { + height: u32, + }, + ChannelClosed { + channel_id: String, + }, +} + +/// Adapter for FSM side-effect actions. 
+#[async_trait]
+pub trait ActionExecutor {
+    async fn fund_channel(
+        &self,
+        peer_id: String,
+        channel_capacity_msat: Msat,
+        opening_fee_params: OpeningFeeParams,
+        scid: ShortChannelId,
+    ) -> Result<(String, String)>;
+
+    async fn abandon_session(&self, channel_id: String, funding_psbt: String) -> Result<()>;
+
+    async fn broadcast_tx(&self, channel_id: String, funding_psbt: String) -> Result;
+
+    async fn disconnect(&self, peer_id: String) -> Result<()>;
+
+    async fn is_channel_alive(&self, channel_id: &str) -> Result;
+}
+
+#[derive(Debug, Clone)]
+pub struct ActorInboxHandle {
+    tx: mpsc::Sender,
+}
+
+impl ActorInboxHandle {
+    pub async fn add_part(&self, part: PaymentPart) -> Result {
+        let (reply_tx, rx) = oneshot::channel();
+        self.tx.send(ActorInput::AddPart { part, reply_tx }).await?;
+        Ok(rx.await?)
+    }
+
+    pub async fn payment_settled(
+        &self,
+        preimage: Option,
+        updated_index: Option,
+    ) -> Result<()> {
+        Ok(self
+            .tx
+            .send(ActorInput::PaymentSettled {
+                preimage,
+                updated_index,
+            })
+            .await?)
+    }
+
+    pub async fn payment_failed(&self, updated_index: Option) -> Result<()> {
+        Ok(self
+            .tx
+            .send(ActorInput::PaymentFailed { updated_index })
+            .await?)
+    }
+
+    pub async fn new_block(&self, height: u32) -> Result<()> {
+        Ok(self.tx.send(ActorInput::NewBlock { height }).await?)
+    }
+}
+
+/// Per-session actor that drives the LSPS2 synchronous session FSM and bridges
+/// it to async side effects.
+///
+/// It's the runtime boundary around a single `Session`. It owns input ordering,
+/// pending HTLC replies, timeout handling, and execution of FSM-emitted side
+/// effects and actions. 
+pub struct SessionActor { + session: Session, + entry: DatastoreEntry, + inbox: mpsc::Receiver, + pending_htlcs: HashMap>, + collect_fired: bool, + channel_poll_handle: Option>, + self_send: mpsc::Sender, + executor: A, + peer_id: String, + collect_timeout_secs: u64, + scid: ShortChannelId, + datastore: D, + event_sink: Arc, +} + +impl + SessionActor +{ + pub fn spawn_session_actor( + session: Session, + entry: DatastoreEntry, + executor: A, + peer_id: String, + collect_timeout_secs: u64, + scid: ShortChannelId, + datastore: D, + event_sink: Arc, + ) -> ActorInboxHandle { + let (tx, inbox) = mpsc::channel(128); // Should we use max_htlcs? + let actor = SessionActor { + session, + entry, + inbox, + pending_htlcs: HashMap::new(), + collect_fired: false, + channel_poll_handle: None, + self_send: tx.clone(), + executor, + peer_id, + collect_timeout_secs, + scid, + datastore, + event_sink, + }; + tokio::spawn(actor.run()); + ActorInboxHandle { tx } + } + + pub fn spawn_recovered_session_actor( + session: Session, + entry: DatastoreEntry, + initial_actions: Vec, + executor: A, + scid: ShortChannelId, + datastore: D, + event_sink: Arc, + ) -> ActorInboxHandle { + let (tx, inbox) = mpsc::channel(128); + let handle = ActorInboxHandle { tx: tx.clone() }; + + let actor = SessionActor { + session, + entry, + inbox, + pending_htlcs: HashMap::new(), + collect_fired: true, + channel_poll_handle: None, + self_send: tx, + executor, + peer_id: String::new(), + collect_timeout_secs: 0, + scid, + datastore, + event_sink, + }; + + tokio::spawn(actor.run_recovered(initial_actions)); + handle + } + + fn dispatch_events(&self, events: Vec) { + let payment_hash = match self.entry.payment_hash.as_deref() { + Some(s) => match s.parse::() { + Ok(h) => h, + Err(e) => { + warn!("malformed payment_hash in datastore for scid={}: {e}", self.scid); + PaymentHash::all_zeros() + } + }, + None => PaymentHash::all_zeros(), + }; + for event in events { + debug!("session event: {:?}", event); + 
self.event_sink.send(&SessionEventEnvelope { + scid: self.scid, + payment_hash, + event, + }); + } + } + + async fn convert_input(&mut self, input: ActorInput) -> Option { + match input { + ActorInput::AddPart { part, reply_tx } => { + let htlc_id = part.htlc_id; + self.pending_htlcs.insert(htlc_id, reply_tx); + Some(SessionInput::AddPart { part }) + } + ActorInput::ChannelReady { + channel_id, + funding_psbt, + } => { + self.entry.channel_id = Some(channel_id.clone()); + self.entry.funding_psbt = Some(funding_psbt.clone()); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on ChannelReady: {e}"); + } + Some(SessionInput::ChannelReady { + channel_id, + funding_psbt, + }) + } + ActorInput::FundingFailed => Some(SessionInput::FundingFailed), + ActorInput::PaymentSettled { + preimage, + updated_index, + } => { + if let Some(index) = updated_index { + self.entry.forwards_updated_index = Some(index); + } + if let Some(ref pre) = preimage { + self.entry.preimage = Some(pre.clone()); + } + if updated_index.is_some() || preimage.is_some() { + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentSettled: {e}"); + } + } + Some(SessionInput::PaymentSettled) + } + ActorInput::PaymentFailed { updated_index } => { + if let Some(index) = updated_index { + self.entry.forwards_updated_index = Some(index); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentFailed: {e}"); + } + } + Some(SessionInput::PaymentFailed) + } + ActorInput::FundingBroadcasted { txid } => { + self.entry.funding_txid = Some(txid); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on FundingBroadcasted: {e}"); + } + Some(SessionInput::FundingBroadcasted) + } + ActorInput::NewBlock { height } => Some(SessionInput::NewBlock { height }), + ActorInput::ChannelClosed { 
channel_id } => { + Some(SessionInput::ChannelClosed { channel_id }) + } + } + } + + /// Apply a session input to the FSM and execute resulting actions. + /// Returns `true` if the session reached a terminal state. + fn apply_and_execute(&mut self, input: SessionInput) -> bool { + match self.session.apply(input) { + Ok(result) => { + self.dispatch_events(result.events); + for action in result.actions { + self.execute_action(action); + } + self.session.is_terminal() + } + Err(e) => { + warn!("session FSM error: {e}"); + if self.session.is_terminal() { + self.release_pending_htlcs(); + true + } else { + false + } + } + } + } + + fn start_channel_poll(&mut self, channel_id: String) { + let tx = self.self_send.clone(); + let executor = self.executor.clone(); + self.channel_poll_handle = Some(tokio::spawn(async move { + let interval = Duration::from_secs(5); + loop { + tokio::time::sleep(interval).await; + match executor.is_channel_alive(&channel_id).await { + Ok(true) => continue, + Ok(false) | Err(_) => { + let _ = tx + .send(ActorInput::ChannelClosed { + channel_id: channel_id.clone(), + }) + .await; + break; + } + } + } + })); + } + + fn cancel_channel_poll(&mut self) { + if let Some(handle) = self.channel_poll_handle.take() { + handle.abort(); + } + } + + async fn run(mut self) { + let collect_deadline = tokio::time::sleep( + Duration::from_secs(self.collect_timeout_secs), + ); + tokio::pin!(collect_deadline); + + loop { + tokio::select! 
{ + input = self.inbox.recv() => { + let Some(input) = input else { break }; + let Some(session_input) = self.convert_input(input).await else { + continue; + }; + if self.apply_and_execute(session_input) { + break; + } + } + _ = &mut collect_deadline, if !self.collect_fired => { + self.collect_fired = true; + if self.apply_and_execute(SessionInput::CollectTimeout) { + break; + } + } + } + } + + self.release_pending_htlcs(); + Self::finalize(&self.session, &self.datastore, self.scid).await; + } + + async fn run_recovered( + mut self, + initial_actions: Vec, + ) { + // Execute initial actions (e.g., BroadcastFundingTx for Broadcasting state) + for action in initial_actions { + self.execute_action(action); + } + + if self.session.is_terminal() { + Self::finalize(&self.session, &self.datastore, self.scid).await; + return; + } + + // Main loop: process inbox events from forward_event notifications + loop { + match self.inbox.recv().await { + Some(actor_input) => { + // Only process settlement/failure/broadcast events + let session_input = match &actor_input { + ActorInput::PaymentSettled { .. } + | ActorInput::PaymentFailed { .. } + | ActorInput::FundingBroadcasted { .. 
} => { + self.convert_input(actor_input).await + } + _ => continue, + }; + + if let Some(input) = session_input { + if self.apply_and_execute(input) { + break; + } + } + } + None => break, + } + } + + Self::finalize(&self.session, &self.datastore, self.scid).await; + } + + async fn finalize(session: &Session, datastore: &D, scid: ShortChannelId) { + if let Some(outcome) = session.outcome() { + if let Err(e) = datastore.finalize_session(&scid, outcome).await { + warn!("finalize_session failed for scid={scid}: {e}"); + } + } + } + + fn execute_action(&mut self, action: SessionAction) { + match action { + SessionAction::FailHtlcs { failure_code } => { + for (_, reply_tx) in self.pending_htlcs.drain() { + let _ = reply_tx.send(HtlcResponse::Fail { failure_code }); + } + } + SessionAction::ForwardHtlcs { parts, channel_id } => { + // First time forwarding HTLCs, we mark the collect timeout as + // fired and start polling the channel for closure: + self.collect_fired = true; + self.start_channel_poll(channel_id.clone()); + for part in &parts { + if let Some(reply_tx) = self.pending_htlcs.remove(&part.htlc_id) { + let _ = reply_tx.send(HtlcResponse::Forward { + channel_id: channel_id.clone(), + fee_msat: part.fee_msat, + forward_msat: part.forward_msat, + }); + } + } + } + SessionAction::FundChannel { + peer_id, + channel_capacity_msat, + opening_fee_params, + } => { + let executor = self.executor.clone(); + let self_tx = self.self_send.clone(); + let scid = self.scid; + tokio::spawn(async move { + match executor + .fund_channel(peer_id, channel_capacity_msat, opening_fee_params, scid) + .await + { + Ok((channel_id, funding_psbt)) => { + let _ = self_tx + .send(ActorInput::ChannelReady { + channel_id, + funding_psbt, + }) + .await; + } + Err(e) => { + warn!("fund_channel failed: {e}"); + let _ = self_tx.send(ActorInput::FundingFailed).await; + } + } + }); + } + SessionAction::FailSession => { + // Is basically a no-op as it is always accompanied with FailHtlcs. 
+ let n = self.release_pending_htlcs(); + debug_assert_eq!(n, 0); + } + SessionAction::AbandonSession { + channel_id, + funding_psbt, + } => { + // Is also basically a no-op as all htlcs should have been + // already forwarded. + let n = self.release_pending_htlcs(); + debug_assert_eq!(n, 0); + + let executor = self.executor.clone(); + tokio::spawn(async move { + if let Err(e) = executor + .abandon_session(channel_id.clone(), funding_psbt.clone()) + .await + { + warn!( + "abandon_session failed (channel_id={}, funding_psbt={}): {}", + channel_id, funding_psbt, e + ); + } + }); + } + SessionAction::BroadcastFundingTx { + channel_id, + funding_psbt, + } => { + self.cancel_channel_poll(); + let executor = self.executor.clone(); + let self_tx = self.self_send.clone(); + tokio::spawn(async move { + match executor + .broadcast_tx(channel_id.clone(), funding_psbt.clone()) + .await + { + Ok(txid) => { + let _ = self_tx.send(ActorInput::FundingBroadcasted { txid }).await; + } + Err(e) => { + warn!( + "broadcast_tx failed (channel_id={}, funding_psbt={}): {}", + channel_id, funding_psbt, e + ); + } + } + }); + } + SessionAction::Disconnect => { + let executor = self.executor.clone(); + let peer_id = self.peer_id.clone(); + tokio::spawn(async move { + if let Err(e) = executor.disconnect(peer_id.clone()).await { + warn!("disconnect failed (peer_id={}): {}", peer_id, e); + } + }); + } + } + } + + fn release_pending_htlcs(&mut self) -> usize { + let n = self.pending_htlcs.iter().len(); + for (_, reply_tx) in self.pending_htlcs.drain() { + let _ = reply_tx.send(HtlcResponse::Continue); + } + n + } +} + +#[async_trait] +impl ActionExecutor for Arc { + async fn fund_channel( + &self, + peer_id: String, + channel_capacity_msat: Msat, + opening_fee_params: OpeningFeeParams, + scid: ShortChannelId, + ) -> Result<(String, String)> { + (**self) + .fund_channel(peer_id, channel_capacity_msat, opening_fee_params, scid) + .await + } + + async fn abandon_session(&self, channel_id: String, 
funding_psbt: String) -> Result<()> { + (**self).abandon_session(channel_id, funding_psbt).await + } + + async fn broadcast_tx(&self, channel_id: String, funding_psbt: String) -> Result { + (**self).broadcast_tx(channel_id, funding_psbt).await + } + + async fn disconnect(&self, peer_id: String) -> Result<()> { + (**self).disconnect(peer_id).await + } + + async fn is_channel_alive(&self, channel_id: &str) -> Result { + (**self).is_channel_alive(channel_id).await + } +} + +#[async_trait] +impl DatastoreProvider for Arc { + async fn store_buy_request( + &self, + scid: &ShortChannelId, + peer_id: &bitcoin::secp256k1::PublicKey, + offer: &OpeningFeeParams, + expected_payment_size: &Option, + channel_capacity_msat: &Msat, + ) -> Result { + (**self) + .store_buy_request(scid, peer_id, offer, expected_payment_size, channel_capacity_msat) + .await + } + + async fn get_buy_request( + &self, + scid: &ShortChannelId, + ) -> Result { + (**self).get_buy_request(scid).await + } + + async fn save_session( + &self, + scid: &ShortChannelId, + entry: &DatastoreEntry, + ) -> Result<()> { + (**self).save_session(scid, entry).await + } + + async fn finalize_session( + &self, + scid: &ShortChannelId, + outcome: crate::proto::lsps2::SessionOutcome, + ) -> Result<()> { + (**self).finalize_session(scid, outcome).await + } + + async fn list_active_sessions(&self) -> Result> { + (**self).list_active_sessions().await + } +} diff --git a/plugins/lsps-plugin/src/core/lsps2/event_sink.rs b/plugins/lsps-plugin/src/core/lsps2/event_sink.rs new file mode 100644 index 000000000000..0911d28a813e --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/event_sink.rs @@ -0,0 +1,130 @@ +use crate::core::lsps2::session::SessionEvent; +use crate::proto::lsps0::ShortChannelId; +use bitcoin::hashes::sha256::Hash as PaymentHash; +use std::sync::Arc; +use tokio::sync::mpsc; + +#[derive(Debug, Clone)] +pub struct SessionEventEnvelope { + pub scid: ShortChannelId, + pub payment_hash: PaymentHash, + pub event: 
SessionEvent, +} + +pub trait EventSink: Send + Sync { + fn send(&self, envelope: &SessionEventEnvelope); +} + +pub struct NoopEventSink; +impl EventSink for NoopEventSink { + fn send(&self, _: &SessionEventEnvelope) {} +} + +pub struct CompositeEventSink { + sinks: Vec>, +} + +impl CompositeEventSink { + pub fn new(sinks: Vec>) -> Self { + Self { sinks } + } +} + +impl EventSink for CompositeEventSink { + fn send(&self, envelope: &SessionEventEnvelope) { + for sink in &self.sinks { + sink.send(envelope); + } + } +} + +pub struct ChannelEventSink { + tx: mpsc::UnboundedSender, +} + +impl ChannelEventSink { + pub fn new() -> (Self, mpsc::UnboundedReceiver) { + let (tx, rx) = mpsc::unbounded_channel(); + (Self { tx }, rx) + } +} + +impl EventSink for ChannelEventSink { + fn send(&self, envelope: &SessionEventEnvelope) { + let _ = self.tx.send(envelope.clone()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::lsps2::session::SessionEvent; + use crate::proto::lsps0::ShortChannelId; + use bitcoin::hashes::sha256::Hash as PaymentHash; + use bitcoin::hashes::Hash; + use std::sync::atomic::{AtomicUsize, Ordering}; + + fn test_envelope() -> SessionEventEnvelope { + SessionEventEnvelope { + scid: ShortChannelId::from(100u64 << 40 | 1u64 << 16), + payment_hash: PaymentHash::from_byte_array([1; 32]), + event: SessionEvent::FundingChannel, + } + } + + struct CountingSink(AtomicUsize); + impl CountingSink { + fn new() -> Self { Self(AtomicUsize::new(0)) } + fn count(&self) -> usize { self.0.load(Ordering::SeqCst) } + } + impl EventSink for CountingSink { + fn send(&self, _: &SessionEventEnvelope) { + self.0.fetch_add(1, Ordering::SeqCst); + } + } + + #[test] + fn noop_sink_does_not_panic() { + let sink = NoopEventSink; + sink.send(&test_envelope()); + } + + #[test] + fn composite_fans_out_to_all_sinks() { + let s1 = Arc::new(CountingSink::new()); + let s2 = Arc::new(CountingSink::new()); + let composite = CompositeEventSink::new(vec![ + s1.clone() as Arc, 
+ s2.clone(), + ]); + composite.send(&test_envelope()); + composite.send(&test_envelope()); + assert_eq!(s1.count(), 2); + assert_eq!(s2.count(), 2); + } + + #[test] + fn composite_with_no_sinks_does_not_panic() { + let composite = CompositeEventSink::new(vec![]); + composite.send(&test_envelope()); + } + + #[tokio::test] + async fn channel_sink_delivers_to_receiver() { + let (sink, mut rx) = ChannelEventSink::new(); + let envelope = test_envelope(); + sink.send(&envelope); + sink.send(&envelope); + let received = rx.recv().await.unwrap(); + assert_eq!(received.scid, envelope.scid); + let received2 = rx.recv().await.unwrap(); + assert_eq!(received2.scid, envelope.scid); + } + + #[test] + fn channel_sink_silently_drops_when_receiver_gone() { + let (sink, rx) = ChannelEventSink::new(); + drop(rx); + sink.send(&test_envelope()); // must not panic + } +} diff --git a/plugins/lsps-plugin/src/core/lsps2/handler.rs b/plugins/lsps-plugin/src/core/lsps2/handler.rs deleted file mode 100644 index 88124788a62f..000000000000 --- a/plugins/lsps-plugin/src/core/lsps2/handler.rs +++ /dev/null @@ -1,1367 +0,0 @@ -use crate::{ - core::lsps2::service::Lsps2Handler, - lsps2::{ - cln::{HtlcAcceptedRequest, HtlcAcceptedResponse, TLV_FORWARD_AMT}, - DS_MAIN_KEY, DS_SUB_KEY, - }, - proto::{ - jsonrpc::{RpcError, RpcErrorExt as _}, - lsps0::{LSPS0RpcErrorExt, Msat, ShortChannelId}, - lsps2::{ - compute_opening_fee, - failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, - DatastoreEntry, Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, - Lsps2GetInfoResponse, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, PolicyOpeningFeeParams, Promise, - }, - }, -}; -use anyhow::{Context, Result as AnyResult}; -use async_trait::async_trait; -use bitcoin::{ - hashes::{sha256::Hash as Sha256, Hash as _}, - secp256k1::PublicKey, -}; -use chrono::Utc; -use cln_rpc::{ - model::{ - 
requests::{ - DatastoreMode, DatastoreRequest, DeldatastoreRequest, FundchannelRequest, - GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, - }, - responses::ListdatastoreResponse, - }, - primitives::{Amount, AmountOrAll, ChannelState}, - ClnRpc, -}; -use log::{debug, warn}; -use rand::{rng, Rng as _}; -use serde::Serialize; -use std::{fmt, path::PathBuf, sync::Arc, time::Duration}; - -const DEFAULT_CLTV_EXPIRY_DELTA: u32 = 144; - -#[derive(Clone)] -pub struct ClnApiRpc { - rpc_path: PathBuf, -} - -impl ClnApiRpc { - pub fn new(rpc_path: PathBuf) -> Self { - Self { rpc_path } - } - - async fn create_rpc(&self) -> AnyResult { - ClnRpc::new(&self.rpc_path).await - } -} - -#[async_trait] -impl LightningProvider for ClnApiRpc { - async fn fund_jit_channel( - &self, - peer_id: &PublicKey, - amount: &Msat, - ) -> AnyResult<(Sha256, String)> { - let mut rpc = self.create_rpc().await?; - let res = rpc - .call_typed(&FundchannelRequest { - announce: Some(false), - close_to: None, - compact_lease: None, - feerate: None, - minconf: None, - mindepth: Some(0), - push_msat: None, - request_amt: None, - reserve: None, - channel_type: Some(vec![12, 46, 50]), - utxos: None, - amount: AmountOrAll::Amount(Amount::from_msat(amount.msat())), - id: peer_id.to_owned(), - }) - .await - .with_context(|| "calling fundchannel")?; - Ok((res.channel_id, res.txid)) - } - - async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> AnyResult { - let mut rpc = self.create_rpc().await?; - let r = rpc - .call_typed(&ListpeerchannelsRequest { - id: Some(peer_id.to_owned()), - short_channel_id: None, - }) - .await - .with_context(|| "calling listpeerchannels")?; - - let chs = r - .channels - .iter() - .find(|&ch| ch.channel_id.is_some_and(|id| id == *channel_id)); - if let Some(ch) = chs { - if ch.state == ChannelState::CHANNELD_NORMAL { - return Ok(true); - } - } - - return Ok(false); - } -} - -#[async_trait] -impl DatastoreProvider for ClnApiRpc { - async fn 
store_buy_request( - &self, - scid: &ShortChannelId, - peer_id: &PublicKey, - opening_fee_params: &OpeningFeeParams, - expected_payment_size: &Option, - ) -> AnyResult { - let mut rpc = self.create_rpc().await?; - #[derive(Serialize)] - struct BorrowedDatastoreEntry<'a> { - peer_id: &'a PublicKey, - opening_fee_params: &'a OpeningFeeParams, - #[serde(borrow)] - expected_payment_size: &'a Option, - } - - let ds = BorrowedDatastoreEntry { - peer_id, - opening_fee_params, - expected_payment_size, - }; - let json_str = serde_json::to_string(&ds)?; - - let ds = DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::MUST_CREATE), - string: Some(json_str), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - scid.to_string(), - ], - }; - - let _ = rpc - .call_typed(&ds) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling datastore")?; - - Ok(true) - } - - async fn get_buy_request(&self, scid: &ShortChannelId) -> AnyResult { - let mut rpc = self.create_rpc().await?; - let key = vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - scid.to_string(), - ]; - let res = rpc - .call_typed(&ListdatastoreRequest { - key: Some(key.clone()), - }) - .await - .with_context(|| "calling listdatastore")?; - - let (rec, _) = deserialize_by_key(&res, key)?; - Ok(rec) - } - - async fn del_buy_request(&self, scid: &ShortChannelId) -> AnyResult<()> { - let mut rpc = self.create_rpc().await?; - let key = vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - scid.to_string(), - ]; - - let _ = rpc - .call_typed(&DeldatastoreRequest { - generation: None, - key, - }) - .await; - - Ok(()) - } -} - -#[async_trait] -impl Lsps2OfferProvider for ClnApiRpc { - async fn get_offer( - &self, - request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult { - let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getpolicy", request) - .await - .context("failed to call lsps2-policy-getpolicy") - } - - async fn get_channel_capacity( 
- &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getchannelcapacity", params) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling lsps2-policy-getchannelcapacity") - } -} - -#[async_trait] -impl BlockheightProvider for ClnApiRpc { - async fn get_blockheight(&self) -> AnyResult { - let mut rpc = self.create_rpc().await?; - let info = rpc - .call_typed(&GetinfoRequest {}) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling getinfo")?; - Ok(info.blockheight) - } -} - -#[async_trait] -pub trait Lsps2OfferProvider: Send + Sync { - async fn get_offer( - &self, - request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult; - - async fn get_channel_capacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult; -} - -type Blockheight = u32; - -#[async_trait] -pub trait BlockheightProvider: Send + Sync { - async fn get_blockheight(&self) -> AnyResult; -} - -#[async_trait] -pub trait DatastoreProvider: Send + Sync { - async fn store_buy_request( - &self, - scid: &ShortChannelId, - peer_id: &PublicKey, - offer: &OpeningFeeParams, - expected_payment_size: &Option, - ) -> AnyResult; - - async fn get_buy_request(&self, scid: &ShortChannelId) -> AnyResult; - async fn del_buy_request(&self, scid: &ShortChannelId) -> AnyResult<()>; -} - -#[async_trait] -pub trait LightningProvider: Send + Sync { - async fn fund_jit_channel( - &self, - peer_id: &PublicKey, - amount: &Msat, - ) -> AnyResult<(Sha256, String)>; - - async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> AnyResult; -} - -pub struct Lsps2ServiceHandler { - pub api: Arc, - pub promise_secret: [u8; 32], -} - -impl Lsps2ServiceHandler { - pub fn new(api: Arc, promise_seret: &[u8; 32]) -> Self { - Lsps2ServiceHandler { - api, - promise_secret: promise_seret.to_owned(), - } - } -} - -async fn get_info_handler( - api: Arc, - secret: &[u8; 32], - request: 
&Lsps2GetInfoRequest, -) -> std::result::Result { - let res_data = api - .get_offer(&Lsps2PolicyGetInfoRequest { - token: request.token.clone(), - }) - .await - .map_err(|_| RpcError::internal_error("internal error"))?; - - if res_data.client_rejected { - return Err(RpcError::client_rejected("client was rejected")); - }; - - let opening_fee_params_menu = res_data - .policy_opening_fee_params_menu - .iter() - .map(|v| make_opening_fee_params(v, secret)) - .collect::, _>>()?; - - Ok(Lsps2GetInfoResponse { - opening_fee_params_menu, - }) -} - -fn make_opening_fee_params( - v: &PolicyOpeningFeeParams, - secret: &[u8; 32], -) -> Result { - let promise: Promise = v - .get_hmac_hex(secret) - .try_into() - .map_err(|_| RpcError::internal_error("internal error"))?; - Ok(OpeningFeeParams { - min_fee_msat: v.min_fee_msat, - proportional: v.proportional, - valid_until: v.valid_until, - min_lifetime: v.min_lifetime, - max_client_to_self_delay: v.max_client_to_self_delay, - min_payment_size_msat: v.min_payment_size_msat, - max_payment_size_msat: v.max_payment_size_msat, - promise, - }) -} - -#[async_trait] -impl Lsps2Handler - for Lsps2ServiceHandler -{ - async fn handle_get_info( - &self, - request: Lsps2GetInfoRequest, - ) -> std::result::Result { - get_info_handler(self.api.clone(), &self.promise_secret, &request).await - } - - async fn handle_buy( - &self, - peer_id: PublicKey, - request: Lsps2BuyRequest, - ) -> core::result::Result { - let fee_params = request.opening_fee_params; - - // FIXME: In the future we should replace the \`None\` with a meaningful - // value that reflects the inbound capacity for this node from the - // public network for a better pre-condition check on the payment_size. - fee_params.validate(&self.promise_secret, request.payment_size_msat, None)?; - - // Generate a tmp scid to identify jit channel request in htlc. 
- let blockheight = self - .api - .get_blockheight() - .await - .map_err(|_| RpcError::internal_error("internal error"))?; - - // FIXME: Future task: Check that we don't conflict with any jit scid we - // already handed out -> Check datastore entries. - let jit_scid = ShortChannelId::from(generate_jit_scid(blockheight)); - - let ok = self - .api - .store_buy_request(&jit_scid, &peer_id, &fee_params, &request.payment_size_msat) - .await - .map_err(|_| RpcError::internal_error("internal error"))?; - - if !ok { - return Err(RpcError::internal_error("internal error"))?; - } - - Ok(Lsps2BuyResponse { - jit_channel_scid: jit_scid, - // We can make this configurable if necessary. - lsp_cltv_expiry_delta: DEFAULT_CLTV_EXPIRY_DELTA, - // We can implement the other mode later on as we might have to do - // some additional work on core-lightning to enable this. - client_trusts_lsp: false, - }) - } -} - -fn generate_jit_scid(best_blockheigt: u32) -> u64 { - let mut rng = rng(); - let block = best_blockheigt + 6; // Approx 1 hour in the future and should avoid collision with confirmed channels - let tx_idx: u32 = rng.random_range(0..5000); - let output_idx: u16 = rng.random_range(0..10); - - ((block as u64) << 40) | ((tx_idx as u64) << 16) | (output_idx as u64) -} - -pub struct HtlcAcceptedHookHandler { - api: A, - htlc_minimum_msat: u64, - backoff_listpeerchannels: Duration, -} - -impl HtlcAcceptedHookHandler { - pub fn new(api: A, htlc_minimum_msat: u64) -> Self { - Self { - api, - htlc_minimum_msat, - backoff_listpeerchannels: Duration::from_secs(10), - } - } -} - -impl HtlcAcceptedHookHandler { - pub async fn handle(&self, req: HtlcAcceptedRequest) -> AnyResult { - let scid = match req.onion.short_channel_id { - Some(scid) => scid, - None => { - // We are the final destination of this htlc. - return Ok(HtlcAcceptedResponse::continue_(None, None, None)); - } - }; - - // A) Is this SCID one that we care about? 
- let ds_rec = match self.api.get_buy_request(&scid).await { - Ok(rec) => rec, - Err(_) => { - return Ok(HtlcAcceptedResponse::continue_(None, None, None)); - } - }; - - // Fixme: Check that we don't have a channel yet with the peer that we await to - // become READY to use. - // --- - - // Fixme: We only accept no-mpp for now, mpp and other flows will be added later on - // Fixme: We continue mpp for now to let the test mock handle the htlc, as we need - // to test the client implementation for mpp payments. - if ds_rec.expected_payment_size.is_some() { - warn!("mpp payments are not implemented yet"); - return Ok(HtlcAcceptedResponse::continue_(None, None, None)); - // return Ok(HtlcAcceptedResponse::fail( - // Some(UNKNOWN_NEXT_PEER.to_string()), - // None, - // )); - } - - // B) Is the fee option menu still valid? - let now = Utc::now(); - if now >= ds_rec.opening_fee_params.valid_until { - // Not valid anymore, remove from DS and fail HTLC. - let _ = self.api.del_buy_request(&scid).await; - return Ok(HtlcAcceptedResponse::fail( - Some(TEMPORARY_CHANNEL_FAILURE.to_string()), - None, - )); - } - - // C) Is the amount in the boundaries of the fee menu? - if req.htlc.amount_msat.msat() < ds_rec.opening_fee_params.min_fee_msat.msat() - || req.htlc.amount_msat.msat() > ds_rec.opening_fee_params.max_payment_size_msat.msat() - { - // No! reject the HTLC. 
- debug!("amount_msat for scid: {}, was too low or to high", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - - // D) Check that the amount_msat covers the opening fee (only for non-mpp right now) - let opening_fee = if let Some(opening_fee) = compute_opening_fee( - req.htlc.amount_msat.msat(), - ds_rec.opening_fee_params.min_fee_msat.msat(), - ds_rec.opening_fee_params.proportional.ppm() as u64, - ) { - if opening_fee + self.htlc_minimum_msat >= req.htlc.amount_msat.msat() { - debug!("amount_msat for scid: {}, does not cover opening fee", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - opening_fee - } else { - // The computation overflowed. - debug!("amount_msat for scid: {}, was too low or to high", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - }; - - // E) We made it, open a channel to the peer. - let ch_cap_req = Lsps2PolicyGetChannelCapacityRequest { - opening_fee_params: ds_rec.opening_fee_params, - init_payment_size: Msat::from_msat(req.htlc.amount_msat.msat()), - scid, - }; - let ch_cap_res = match self.api.get_channel_capacity(&ch_cap_req).await { - Ok(r) => r, - Err(e) => { - warn!("failed to get channel capacity for scid {}: {}", scid, e); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - }; - - let cap = match ch_cap_res.channel_capacity_msat { - Some(c) => Msat::from_msat(c), - None => { - debug!("policy giver does not allow channel for scid {}", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - }; - - // We take the policy-giver seriously, if the capacity is too low, we - // still try to open the channel. - // Fixme: We may check that the capacity is ge than the - // (amount_msat - opening fee) in the future. 
- // Fixme: Make this configurable, maybe return the whole request from - // the policy giver? - let channel_id = match self.api.fund_jit_channel(&ds_rec.peer_id, &cap).await { - Ok((channel_id, _)) => channel_id, - Err(_) => { - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - }; - - // F) Wait for the peer to send `channel_ready`. - // Fixme: Use event to check for channel ready, - // Fixme: Check for htlc timeout if peer refuses to send "ready". - // Fixme: handle unexpected channel states. - loop { - match self - .api - .is_channel_ready(&ds_rec.peer_id, &channel_id) - .await - { - Ok(true) => break, - Ok(false) | Err(_) => tokio::time::sleep(self.backoff_listpeerchannels).await, - }; - } - - // G) We got a working channel, deduct fee and forward htlc. - let deducted_amt_msat = req.htlc.amount_msat.msat() - opening_fee; - let mut payload = req.onion.payload.clone(); - payload.set_tu64(TLV_FORWARD_AMT, deducted_amt_msat); - - // It is okay to unwrap the next line as we do not have duplicate entries. - let payload_bytes = payload.to_bytes().unwrap(); - debug!("about to send payload: {:02x?}", &payload_bytes); - - let mut extra_tlvs = req.htlc.extra_tlvs.unwrap_or_default().clone(); - extra_tlvs.set_u64(65537, opening_fee); - let extra_tlvs_bytes = extra_tlvs.to_bytes().unwrap(); - debug!("extra_tlv: {:02x?}", extra_tlvs_bytes); - - Ok(HtlcAcceptedResponse::continue_( - Some(payload_bytes), - Some(channel_id.as_byte_array().to_vec()), - Some(extra_tlvs_bytes), - )) - } -} - -#[derive(Debug)] -pub enum DsError { - /// No datastore entry with this exact key. - NotFound { key: Vec }, - /// Entry existed but had neither `string` nor `hex`. - MissingValue { key: Vec }, - /// JSON parse failed (from `string` or decoded `hex`). - JsonParse { - key: Vec, - source: serde_json::Error, - }, - /// Hex decode failed. 
- HexDecode { - key: Vec, - source: hex::FromHexError, - }, -} - -impl fmt::Display for DsError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - DsError::NotFound { key } => write!(f, "no datastore entry for key {:?}", key), - DsError::MissingValue { key } => write!( - f, - "datastore entry had neither `string` nor `hex` for key {:?}", - key - ), - DsError::JsonParse { key, source } => { - write!(f, "failed to parse JSON at key {:?}: {}", key, source) - } - DsError::HexDecode { key, source } => { - write!(f, "failed to decode hex at key {:?}: {}", key, source) - } - } - } -} - -impl std::error::Error for DsError {} - -pub fn deserialize_by_key( - resp: &ListdatastoreResponse, - key: K, -) -> std::result::Result<(DatastoreEntry, Option), DsError> -where - K: AsRef<[String]>, -{ - let wanted: &[String] = key.as_ref(); - - let ds = resp - .datastore - .iter() - .find(|d| d.key.as_slice() == wanted) - .ok_or_else(|| DsError::NotFound { - key: wanted.to_vec(), - })?; - - // Prefer `string`, fall back to `hex` - if let Some(s) = &ds.string { - let value = serde_json::from_str::(s).map_err(|e| DsError::JsonParse { - key: ds.key.clone(), - source: e, - })?; - return Ok((value, ds.generation)); - } - - if let Some(hx) = &ds.hex { - let bytes = hex::decode(hx).map_err(|e| DsError::HexDecode { - key: ds.key.clone(), - source: e, - })?; - let value = - serde_json::from_slice::(&bytes).map_err(|e| DsError::JsonParse { - key: ds.key.clone(), - source: e, - })?; - return Ok((value, ds.generation)); - } - - Err(DsError::MissingValue { - key: ds.key.clone(), - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - lsps2::cln::{tlv::TlvStream, HtlcAcceptedResult}, - proto::{ - jsonrpc, - lsps0::Ppm, - lsps2::{Lsps2PolicyGetInfoResponse, PolicyOpeningFeeParams}, - }, - }; - use anyhow::bail; - use chrono::{TimeZone, Utc}; - use cln_rpc::primitives::{Amount, PublicKey}; - use cln_rpc::RpcError as ClnRpcError; - use std::sync::{Arc, Mutex}; - 
- const PUBKEY: [u8; 33] = [ - 0x02, 0x79, 0xbe, 0x66, 0x7e, 0xf9, 0xdc, 0xbb, 0xac, 0x55, 0xa0, 0x62, 0x95, 0xce, 0x87, - 0x0b, 0x07, 0x02, 0x9b, 0xfc, 0xdb, 0x2d, 0xce, 0x28, 0xd9, 0x59, 0xf2, 0x81, 0x5b, 0x16, - 0xf8, 0x17, 0x98, - ]; - - fn create_peer_id() -> PublicKey { - PublicKey::from_slice(&PUBKEY).expect("Valid pubkey") - } - - /// Build a pair: policy params + buy params with a Promise derived from `secret` - fn params_with_promise(secret: &[u8; 32]) -> (PolicyOpeningFeeParams, OpeningFeeParams) { - let policy = PolicyOpeningFeeParams { - min_fee_msat: Msat(2_000), - proportional: Ppm(10_000), - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1_000_000), - max_payment_size_msat: Msat(100_000_000), - }; - let hex = policy.get_hmac_hex(secret); - let promise: Promise = hex.try_into().expect("hex->Promise"); - let buy = OpeningFeeParams { - min_fee_msat: policy.min_fee_msat, - proportional: policy.proportional, - valid_until: policy.valid_until, - min_lifetime: policy.min_lifetime, - max_client_to_self_delay: policy.max_client_to_self_delay, - min_payment_size_msat: policy.min_payment_size_msat, - max_payment_size_msat: policy.max_payment_size_msat, - promise, - }; - (policy, buy) - } - - #[derive(Clone, Default)] - struct FakeCln { - lsps2_getpolicy_response: Arc>>, - lsps2_getpolicy_error: Arc>>, - blockheight_response: Option, - blockheight_error: Arc>>, - store_buy_request_response: bool, - get_buy_request_response: Arc>>, - get_buy_request_error: Arc>>, - fund_channel_error: Arc>>, - fund_channel_response: Arc>>, - lsps2_getchannelcapacity_response: - Arc>>, - lsps2_getchannelcapacity_error: Arc>>, - } - - #[async_trait] - impl Lsps2OfferProvider for FakeCln { - async fn get_offer( - &self, - _request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult { - if let Some(err) = self.lsps2_getpolicy_error.lock().unwrap().take() { - return 
Err(anyhow::Error::new(err).context("from fake api")); - }; - if let Some(res) = self.lsps2_getpolicy_response.lock().unwrap().take() { - return Ok(Lsps2PolicyGetInfoResponse { - policy_opening_fee_params_menu: res.policy_opening_fee_params_menu, - client_rejected: false, - }); - }; - panic!("No lsps2 response defined"); - } - - async fn get_channel_capacity( - &self, - _params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - if let Some(err) = self.lsps2_getchannelcapacity_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); - } - if let Some(res) = self - .lsps2_getchannelcapacity_response - .lock() - .unwrap() - .take() - { - return Ok(res); - } - panic!("No lsps2 getchannelcapacity response defined"); - } - } - - #[async_trait] - impl BlockheightProvider for FakeCln { - async fn get_blockheight(&self) -> AnyResult { - if let Some(err) = self.blockheight_error.lock().unwrap().take() { - return Err(err); - }; - if let Some(blockheight) = self.blockheight_response { - return Ok(blockheight); - }; - panic!("No cln getinfo response defined"); - } - } - - #[async_trait] - impl DatastoreProvider for FakeCln { - async fn store_buy_request( - &self, - _scid: &ShortChannelId, - _peer_id: &PublicKey, - _offer: &OpeningFeeParams, - _payment_size_msat: &Option, - ) -> AnyResult { - Ok(self.store_buy_request_response) - } - - async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { - if let Some(err) = self.get_buy_request_error.lock().unwrap().take() { - return Err(err); - } - if let Some(res) = self.get_buy_request_response.lock().unwrap().take() { - return Ok(res); - } else { - bail!("request not found") - } - } - - async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { - Ok(()) - } - } - - #[async_trait] - impl LightningProvider for FakeCln { - async fn fund_jit_channel( - &self, - _peer_id: &PublicKey, - _amount: &Msat, - ) -> AnyResult<(Sha256, String)> { - if let Some(err) = 
self.fund_channel_error.lock().unwrap().take() { - return Err(err); - } - if let Some(res) = self.fund_channel_response.lock().unwrap().take() { - return Ok(res); - } else { - bail!("request not found") - } - } - - async fn is_channel_ready( - &self, - _peer_id: &PublicKey, - _channel_id: &Sha256, - ) -> AnyResult { - Ok(true) - } - } - - fn create_test_htlc_request( - scid: Option, - amount_msat: u64, - ) -> HtlcAcceptedRequest { - let payload = TlvStream::default(); - - HtlcAcceptedRequest { - onion: crate::lsps2::cln::Onion { - short_channel_id: scid, - payload, - next_onion: vec![], - forward_msat: None, - outgoing_cltv_value: None, - shared_secret: vec![], - total_msat: None, - type_: None, - }, - htlc: crate::lsps2::cln::Htlc { - amount_msat: Amount::from_msat(amount_msat), - cltv_expiry: 100, - cltv_expiry_relative: 10, - payment_hash: vec![], - extra_tlvs: None, - short_channel_id: ShortChannelId::from(123456789u64), - id: 0, - }, - forward_to: None, - } - } - - fn create_test_datastore_entry( - peer_id: PublicKey, - expected_payment_size: Option, - ) -> DatastoreEntry { - let (_, policy) = params_with_promise(&[0u8; 32]); - DatastoreEntry { - peer_id, - opening_fee_params: policy, - expected_payment_size, - } - } - - fn test_promise_secret() -> [u8; 32] { - [0x42; 32] - } - - #[tokio::test] - async fn test_successful_get_info() { - let promise_secret = test_promise_secret(); - let params = Lsps2PolicyGetInfoResponse { - client_rejected: false, - policy_opening_fee_params_menu: vec![PolicyOpeningFeeParams { - min_fee_msat: Msat(2000), - proportional: Ppm(10000), - valid_until: Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1000000), - max_payment_size_msat: Msat(100000000), - }], - }; - let promise = params.policy_opening_fee_params_menu[0].get_hmac_hex(&promise_secret); - let fake = FakeCln::default(); - *fake.lsps2_getpolicy_response.lock().unwrap() = Some(params); - 
- let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let request = Lsps2GetInfoRequest { token: None }; - let result = handler.handle_get_info(request).await.unwrap(); - - assert_eq!( - result.opening_fee_params_menu[0].min_payment_size_msat, - Msat(1000000) - ); - assert_eq!( - result.opening_fee_params_menu[0].max_payment_size_msat, - Msat(100000000) - ); - assert_eq!( - result.opening_fee_params_menu[0].promise, - promise.try_into().unwrap() - ); - } - - #[tokio::test] - async fn test_get_info_rpc_error_handling() { - let promise_secret = test_promise_secret(); - let fake = FakeCln::default(); - *fake.lsps2_getpolicy_error.lock().unwrap() = Some(ClnRpcError { - code: Some(-1), - message: "not found".to_string(), - data: None, - }); - - let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let request = Lsps2GetInfoRequest { token: None }; - let result = handler.handle_get_info(request).await; - - assert!(result.is_err()); - let error = result.unwrap_err(); - assert_eq!(error.code, jsonrpc::INTERNAL_ERROR); - assert!(error.message.contains("internal error")); - } - - #[tokio::test] - async fn buy_ok_fixed_amount() { - let promise_secret = test_promise_secret(); - let mut fake = FakeCln::default(); - fake.blockheight_response = Some(900_000); - fake.store_buy_request_response = true; - - let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let (_policy, buy) = params_with_promise(&promise_secret); - - // Set payment_size_msat => "MPP+fixed-invoice" mode. 
- let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: Some(Msat(2_000_000)), - }; - let peer_id = create_peer_id(); - - let result = handler.handle_buy(peer_id, request).await.unwrap(); - - assert_eq!(result.lsp_cltv_expiry_delta, DEFAULT_CLTV_EXPIRY_DELTA); - assert!(!result.client_trusts_lsp); - assert!(result.jit_channel_scid.to_u64() > 0); - } - - #[tokio::test] - async fn buy_ok_variable_amount_no_payment_size() { - let promise_secret = test_promise_secret(); - let mut fake = FakeCln::default(); - fake.blockheight_response = Some(900_100); - fake.store_buy_request_response = true; - - let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let (_policy, buy) = params_with_promise(&promise_secret); - - // No payment_size_msat => "no-MPP+var-invoice" mode. - let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: None, - }; - let peer_id = create_peer_id(); - - let result = handler.handle_buy(peer_id, request).await; - - assert!(result.is_ok()); - } - - #[tokio::test] - async fn buy_rejects_invalid_promise_or_past_valid_until_with_201() { - let promise_secret = test_promise_secret(); - let handler = Lsps2ServiceHandler { - api: Arc::new(FakeCln::default()), - promise_secret, - }; - - // Case A: wrong promise (derive with different secret) - let (_policy_wrong, mut buy_wrong) = params_with_promise(&[9u8; 32]); - buy_wrong.valid_until = Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(); // future, so only promise is wrong - let req_wrong = Lsps2BuyRequest { - opening_fee_params: buy_wrong, - payment_size_msat: Some(Msat(2_000_000)), - }; - let peer_id = create_peer_id(); - - let err1 = handler.handle_buy(peer_id, req_wrong).await.unwrap_err(); - assert_eq!(err1.code, 201); - - // Case B: past valid_until - let (_policy, mut buy_past) = params_with_promise(&promise_secret); - buy_past.valid_until = Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(); // past - let req_past = 
Lsps2BuyRequest { - opening_fee_params: buy_past, - payment_size_msat: Some(Msat(2_000_000)), - }; - let err2 = handler.handle_buy(peer_id, req_past).await.unwrap_err(); - assert_eq!(err2.code, 201); - } - - #[tokio::test] - async fn buy_rejects_when_opening_fee_ge_payment_size_with_202() { - let promise_secret = test_promise_secret(); - let handler = Lsps2ServiceHandler { - api: Arc::new(FakeCln::default()), - promise_secret, - }; - - // Make min_fee already >= payment_size to trigger 202 - let policy = PolicyOpeningFeeParams { - min_fee_msat: Msat(10_000), - proportional: Ppm(0), // no extra percentage - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1), - max_payment_size_msat: Msat(u64::MAX / 2), - }; - let hex = policy.get_hmac_hex(&promise_secret); - let promise: Promise = hex.try_into().unwrap(); - let buy = OpeningFeeParams { - min_fee_msat: policy.min_fee_msat, - proportional: policy.proportional, - valid_until: policy.valid_until, - min_lifetime: policy.min_lifetime, - max_client_to_self_delay: policy.max_client_to_self_delay, - min_payment_size_msat: policy.min_payment_size_msat, - max_payment_size_msat: policy.max_payment_size_msat, - promise, - }; - - let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: Some(Msat(9_999)), // strictly less than min_fee => opening_fee >= payment_size - }; - let peer_id = create_peer_id(); - let err = handler.handle_buy(peer_id, request).await.unwrap_err(); - - assert_eq!(err.code, 202); - } - - #[tokio::test] - async fn buy_rejects_on_fee_overflow_with_203() { - let promise_secret = test_promise_secret(); - let handler = Lsps2ServiceHandler { - api: Arc::new(FakeCln::default()), - promise_secret, - }; - - // Choose values likely to overflow if multiplication isn't checked: - // opening_fee = min_fee + payment_size * proportional / 1_000_000 - let policy = PolicyOpeningFeeParams { - 
min_fee_msat: Msat(u64::MAX / 2), - proportional: Ppm(u32::MAX), // 4_294_967_295 ppm - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1), - max_payment_size_msat: Msat(u64::MAX), - }; - let hex = policy.get_hmac_hex(&promise_secret); - let promise: Promise = hex.try_into().unwrap(); - let buy = OpeningFeeParams { - min_fee_msat: policy.min_fee_msat, - proportional: policy.proportional, - valid_until: policy.valid_until, - min_lifetime: policy.min_lifetime, - max_client_to_self_delay: policy.max_client_to_self_delay, - min_payment_size_msat: policy.min_payment_size_msat, - max_payment_size_msat: policy.max_payment_size_msat, - promise, - }; - - let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: Some(Msat(u64::MAX / 2)), - }; - let peer_id = create_peer_id(); - let err = handler.handle_buy(peer_id, request).await.unwrap_err(); - - assert_eq!(err.code, 203); - } - #[tokio::test] - async fn test_htlc_no_scid_continues() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake, 1000); - - // HTLC with no short_channel_id (final destination) - let req = create_test_htlc_request(None, 1000000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Continue); - } - - #[tokio::test] - async fn test_htlc_unknown_scid_continues() { - let fake = FakeCln::default(); - - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let scid = ShortChannelId::from(123456789u64); - - let req = create_test_htlc_request(Some(scid), 1000000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Continue); - } - - #[tokio::test] - async fn test_htlc_expired_fee_menu_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = 
ShortChannelId::from(123456789u64); - - // Create datastore entry with expired fee menu - let mut ds_entry = create_test_datastore_entry(peer_id, None); - ds_entry.opening_fee_params.valid_until = - Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(); // expired - - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - let req = create_test_htlc_request(Some(scid), 1000000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - TEMPORARY_CHANNEL_FAILURE.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_amount_too_low_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // HTLC amount below minimum - let req = create_test_htlc_request(Some(scid), 100); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_amount_too_high_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // HTLC amount above maximum - let req = create_test_htlc_request(Some(scid), 200_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn 
test_htlc_amount_doesnt_cover_fee_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // HTLC amount just barely covers minimum fee but not minimum HTLC - let req = create_test_htlc_request(Some(scid), 2500); // min_fee is 2000, htlc_minimum is 1000 - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_channel_capacity_request_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - *fake.lsps2_getchannelcapacity_error.lock().unwrap() = Some(ClnRpcError { - code: Some(-1), - message: "capacity check failed".to_string(), - data: None, - }); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_policy_denies_channel() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // Policy response with no channel capacity (denied) - 
*fake.lsps2_getchannelcapacity_response.lock().unwrap() = - Some(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: None, - }); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_fund_channel_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - *fake.lsps2_getchannelcapacity_response.lock().unwrap() = - Some(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: Some(50_000_000), - }); - - *fake.fund_channel_error.lock().unwrap() = Some(anyhow::anyhow!("insufficient funds")); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_successful_flow() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler { - api: fake.clone(), - htlc_minimum_msat: 1000, - backoff_listpeerchannels: Duration::from_millis(10), - }; - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - *fake.lsps2_getchannelcapacity_response.lock().unwrap() = - Some(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: Some(50_000_000), - }); - - *fake.fund_channel_response.lock().unwrap() = - Some((*Sha256::from_bytes_ref(&[1u8; 32]), 
String::default())); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Continue); - - assert!(result.payload.is_some()); - assert!(result.extra_tlvs.is_some()); - assert!(result.forward_to.is_some()); - - // The payload should have the deducted amount - let payload_bytes = result.payload.unwrap(); - let payload_tlv = TlvStream::from_bytes(&payload_bytes).unwrap(); - - // Should contain forward amount. - assert!(payload_tlv.get(TLV_FORWARD_AMT).is_some()); - } - - #[tokio::test] - #[ignore] // We deactivate the mpp check on the experimental server for - // client side checks. - async fn test_htlc_mpp_not_implemented() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - // Create entry with expected_payment_size (MPP mode) - let mut ds_entry = create_test_datastore_entry(peer_id, None); - ds_entry.expected_payment_size = Some(Msat::from_msat(1000000)); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } -} diff --git a/plugins/lsps-plugin/src/core/lsps2/htlc.rs b/plugins/lsps-plugin/src/core/lsps2/htlc.rs deleted file mode 100644 index 6e39cc07cf51..000000000000 --- a/plugins/lsps-plugin/src/core/lsps2/htlc.rs +++ /dev/null @@ -1,802 +0,0 @@ -use crate::{ - core::{ - lsps2::provider::{DatastoreProvider, LightningProvider, Lsps2OfferProvider}, - tlv::{TlvStream, TLV_FORWARD_AMT}, - }, - proto::{ - lsps0::{Msat, ShortChannelId}, - lsps2::{ - compute_opening_fee, - failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, - 
Lsps2PolicyGetChannelCapacityRequest, - }, - }, -}; -use bitcoin::hashes::sha256::Hash; -use chrono::Utc; -use std::time::Duration; -use thiserror::Error; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum HtlcDecision { - NotOurs, - Forward { - payload: TlvStream, - forward_to: Hash, - extra_tlvs: TlvStream, - }, - - Reject { - reason: RejectReason, - }, -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum RejectReason { - OfferExpired { valid_until: chrono::DateTime }, - AmountBelowMinimum { minimum: Msat }, - AmountAboveMaximum { maximum: Msat }, - InsufficientForFee { fee: Msat }, - FeeOverflow, - PolicyDenied, - FundingFailed, - - // temporarily - MppNotSupported, -} - -impl RejectReason { - pub fn failure_code(&self) -> &'static str { - match self { - Self::OfferExpired { .. } => TEMPORARY_CHANNEL_FAILURE, - _ => UNKNOWN_NEXT_PEER, - } - } -} - -#[derive(Debug, Error)] -pub enum HtlcError { - #[error("failed to query channel capacity: {0}")] - CapacityQuery(#[source] anyhow::Error), - #[error("failed to fund channel: {0}")] - FundChannel(#[source] anyhow::Error), - #[error("channel ready check failed: {0}")] - ChannelReadyCheck(#[source] anyhow::Error), -} - -#[derive(Debug, Clone)] -pub struct Htlc { - pub amount_msat: Msat, - pub extra_tlvs: TlvStream, -} -impl Htlc { - pub fn new(amount_msat: Msat, tlvs: TlvStream) -> Self { - Self { - amount_msat, - extra_tlvs: tlvs, - } - } -} - -#[derive(Debug, Clone)] -pub struct Onion { - pub short_channel_id: ShortChannelId, - pub payload: TlvStream, -} - -pub struct HtlcAcceptedHookHandler { - api: A, - htlc_minimum_msat: u64, - backoff_listpeerchannels: Duration, -} - -impl HtlcAcceptedHookHandler { - pub fn new(api: A, htlc_minimum_msat: u64) -> Self { - Self { - api, - htlc_minimum_msat, - backoff_listpeerchannels: Duration::from_secs(10), - } - } -} -impl HtlcAcceptedHookHandler { - pub async fn handle(&self, htlc: &Htlc, onion: &Onion) -> Result { - // A) Is this SCID one 
that we care about? - let ds_rec = match self.api.get_buy_request(&onion.short_channel_id).await { - Ok(rec) => rec, - Err(_) => return Ok(HtlcDecision::NotOurs), - }; - - // Fixme: Check that we don't have a channel yet with the peer that we await to - // become READY to use. - // --- - - // Fixme: We only accept no-mpp for now, mpp and other flows will be added later on - // Fixme: We continue mpp for now to let the test mock handle the htlc, as we need - // to test the client implementation for mpp payments. - if ds_rec.expected_payment_size.is_some() { - return Ok(HtlcDecision::Reject { - reason: RejectReason::MppNotSupported, - }); - } - - // B) Is the fee option menu still valid? - if Utc::now() >= ds_rec.opening_fee_params.valid_until { - // Not valid anymore, remove from DS and fail HTLC. - let _ = self.api.del_buy_request(&onion.short_channel_id).await; - return Ok(HtlcDecision::Reject { - reason: RejectReason::OfferExpired { - valid_until: ds_rec.opening_fee_params.valid_until, - }, - }); - } - - // C) Is the amount in the boundaries of the fee menu? 
- if htlc.amount_msat.msat() < ds_rec.opening_fee_params.min_fee_msat.msat() { - return Ok(HtlcDecision::Reject { - reason: RejectReason::AmountBelowMinimum { - minimum: ds_rec.opening_fee_params.min_fee_msat, - }, - }); - } - - if htlc.amount_msat.msat() > ds_rec.opening_fee_params.max_payment_size_msat.msat() { - return Ok(HtlcDecision::Reject { - reason: RejectReason::AmountAboveMaximum { - maximum: ds_rec.opening_fee_params.max_payment_size_msat, - }, - }); - } - - // D) Check that the amount_msat covers the opening fee (only for non-mpp right now) - let opening_fee = match compute_opening_fee( - htlc.amount_msat.msat(), - ds_rec.opening_fee_params.min_fee_msat.msat(), - ds_rec.opening_fee_params.proportional.ppm() as u64, - ) { - Some(fee) if fee + self.htlc_minimum_msat < htlc.amount_msat.msat() => fee, - Some(fee) => { - return Ok(HtlcDecision::Reject { - reason: RejectReason::InsufficientForFee { - fee: Msat::from_msat(fee), - }, - }) - } - None => { - return Ok(HtlcDecision::Reject { - reason: RejectReason::FeeOverflow, - }) - } - }; - - // E) We made it, open a channel to the peer. - let ch_cap_req = Lsps2PolicyGetChannelCapacityRequest { - opening_fee_params: ds_rec.opening_fee_params, - init_payment_size: htlc.amount_msat, - scid: onion.short_channel_id, - }; - let ch_cap_res = self - .api - .get_channel_capacity(&ch_cap_req) - .await - .map_err(HtlcError::CapacityQuery)?; - - let cap = match ch_cap_res.channel_capacity_msat { - Some(c) => Msat::from_msat(c), - None => { - return Ok(HtlcDecision::Reject { - reason: RejectReason::PolicyDenied, - }) - } - }; - - // We take the policy-giver seriously, if the capacity is too low, we - // still try to open the channel. - // Fixme: We may check that the capacity is ge than the - // (amount_msat - opening fee) in the future. - // Fixme: Make this configurable, maybe return the whole request from - // the policy giver? 
- let (channel_id, _) = self - .api - .fund_jit_channel(&ds_rec.peer_id, &cap) - .await - .map_err(HtlcError::FundChannel)?; - - // F) Wait for the peer to send `channel_ready`. - // Fixme: Use event to check for channel ready, - // Fixme: Check for htlc timeout if peer refuses to send "ready". - // Fixme: handle unexpected channel states. - loop { - match self - .api - .is_channel_ready(&ds_rec.peer_id, &channel_id) - .await - { - Ok(true) => break, - Ok(false) => tokio::time::sleep(self.backoff_listpeerchannels).await, - Err(e) => return Err(HtlcError::ChannelReadyCheck(e)), - }; - } - - // G) We got a working channel, deduct fee and forward htlc. - let deducted_amt_msat = htlc.amount_msat.msat() - opening_fee; - let mut payload = onion.payload.clone(); - payload.set_tu64(TLV_FORWARD_AMT, deducted_amt_msat); - - let mut extra_tlvs = htlc.extra_tlvs.clone(); - extra_tlvs.set_u64(65537, opening_fee); - - Ok(HtlcDecision::Forward { - payload, - forward_to: channel_id, - extra_tlvs, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::core::tlv::TlvStream; - use crate::proto::lsps0::{Msat, Ppm, ShortChannelId}; - use crate::proto::lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, Promise, - }; - use anyhow::{anyhow, Result as AnyResult}; - use async_trait::async_trait; - use bitcoin::hashes::{sha256::Hash as Sha256, Hash}; - use bitcoin::secp256k1::PublicKey; - use chrono::{TimeZone, Utc}; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::{Arc, Mutex}; - use std::time::Duration; - use std::u64; - - fn test_peer_id() -> PublicKey { - "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" - .parse() - .unwrap() - } - - fn test_scid() -> ShortChannelId { - ShortChannelId::from(123456789u64) - } - - fn test_channel_id() -> Sha256 { - Sha256::from_byte_array([1u8; 32]) - } - - fn valid_opening_fee_params() -> OpeningFeeParams { - 
OpeningFeeParams { - min_fee_msat: Msat(2_000), - proportional: Ppm(10_000), // 1% - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 2016, - min_payment_size_msat: Msat(1_000_000), - max_payment_size_msat: Msat(100_000_000), - promise: Promise::try_from("test").unwrap(), - } - } - - fn expired_opening_fee_params() -> OpeningFeeParams { - OpeningFeeParams { - valid_until: Utc.with_ymd_and_hms(2000, 1, 1, 0, 0, 0).unwrap(), - ..valid_opening_fee_params() - } - } - - fn test_datastore_entry(expected_payment_size: Option) -> DatastoreEntry { - DatastoreEntry { - peer_id: test_peer_id(), - opening_fee_params: valid_opening_fee_params(), - expected_payment_size, - } - } - - fn test_onion(scid: ShortChannelId, payload: TlvStream) -> Onion { - Onion { - short_channel_id: scid, - payload, - } - } - - fn test_htlc(amount_msat: u64, extra_tlvs: TlvStream) -> Htlc { - Htlc { - amount_msat: Msat::from_msat(amount_msat), - extra_tlvs, - } - } - - #[derive(Default, Clone)] - struct MockApi { - // Datastore - buy_request: Arc>>, - buy_request_error: Arc>, - del_called: Arc, - - // Policy - channel_capacity: Arc>>>, // Some(Some(cap)), Some(None) = denied, None = error - channel_capacity_error: Arc>, - - // Lightning - fund_result: Arc>>, - fund_error: Arc>, - channel_ready: Arc>, - channel_ready_checks: Arc, - } - - impl MockApi { - fn new() -> Self { - Self::default() - } - - fn with_buy_request(self, entry: DatastoreEntry) -> Self { - *self.buy_request.lock().unwrap() = Some(entry); - self - } - - fn with_no_buy_request(self) -> Self { - *self.buy_request_error.lock().unwrap() = true; - self - } - - fn with_channel_capacity(self, capacity_msat: u64) -> Self { - *self.channel_capacity.lock().unwrap() = Some(Some(capacity_msat)); - self - } - - fn with_channel_denied(self) -> Self { - *self.channel_capacity.lock().unwrap() = Some(None); - self - } - - fn with_channel_capacity_error(self) -> Self { - 
*self.channel_capacity_error.lock().unwrap() = true; - self - } - - fn with_fund_result(self, channel_id: Sha256, txid: &str) -> Self { - *self.fund_result.lock().unwrap() = Some((channel_id, txid.to_string())); - self - } - - fn with_fund_error(self) -> Self { - *self.fund_error.lock().unwrap() = true; - self - } - - fn with_channel_ready(self, ready: bool) -> Self { - *self.channel_ready.lock().unwrap() = ready; - self - } - - fn del_call_count(&self) -> usize { - self.del_called.load(Ordering::SeqCst) - } - - fn channel_ready_check_count(&self) -> usize { - self.channel_ready_checks.load(Ordering::SeqCst) - } - } - - #[async_trait] - impl DatastoreProvider for MockApi { - async fn store_buy_request( - &self, - _scid: &ShortChannelId, - _peer_id: &PublicKey, - _fee_params: &OpeningFeeParams, - _payment_size: &Option, - ) -> AnyResult { - unimplemented!("not needed for HTLC tests") - } - - async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { - if *self.buy_request_error.lock().unwrap() { - return Err(anyhow!("not found")); - } - self.buy_request - .lock() - .unwrap() - .clone() - .ok_or_else(|| anyhow!("not found")) - } - - async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { - self.del_called.fetch_add(1, Ordering::SeqCst); - Ok(()) - } - } - - #[async_trait] - impl Lsps2OfferProvider for MockApi { - async fn get_offer( - &self, - _request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult { - unimplemented!("not needed for HTLC tests") - } - - async fn get_channel_capacity( - &self, - _params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - if *self.channel_capacity_error.lock().unwrap() { - return Err(anyhow!("capacity error")); - } - let cap = self - .channel_capacity - .lock() - .unwrap() - .ok_or_else(|| anyhow!("no capacity set"))?; - Ok(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: cap, - }) - } - } - - #[async_trait] - impl LightningProvider for MockApi { - async fn fund_jit_channel( - 
&self, - _peer_id: &PublicKey, - _amount: &Msat, - ) -> AnyResult<(Sha256, String)> { - if *self.fund_error.lock().unwrap() { - return Err(anyhow!("fund error")); - } - self.fund_result - .lock() - .unwrap() - .clone() - .ok_or_else(|| anyhow!("no fund result set")) - } - - async fn is_channel_ready( - &self, - _peer_id: &PublicKey, - _channel_id: &Sha256, - ) -> AnyResult { - self.channel_ready_checks.fetch_add(1, Ordering::SeqCst); - Ok(*self.channel_ready.lock().unwrap()) - } - } - - fn handler(api: MockApi) -> HtlcAcceptedHookHandler { - HtlcAcceptedHookHandler { - api, - htlc_minimum_msat: 1_000, - backoff_listpeerchannels: Duration::from_millis(1), // Fast for tests - } - } - - #[tokio::test] - async fn continues_when_scid_not_found() { - let api = MockApi::new().with_no_buy_request(); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert_eq!(result, HtlcDecision::NotOurs); - } - - #[tokio::test] - async fn continues_when_mpp_payment() { - let entry = test_datastore_entry(Some(Msat(50_000_000))); // MPP = has expected size - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert_eq!( - result, - HtlcDecision::Reject { - reason: RejectReason::MppNotSupported - } - ); - } - - #[tokio::test] - async fn fails_when_offer_expired() { - let mut entry = test_datastore_entry(None); - entry.opening_fee_params = expired_opening_fee_params(); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api.clone()); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - 
assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::OfferExpired { .. } - } - )); - assert_eq!(api.del_call_count(), 1); // Should delete expired entry - } - - #[tokio::test] - async fn fails_when_amount_below_min_fee() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - // min_fee_msat is 2_000 - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(1_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::AmountBelowMinimum { .. } - } - )); - } - - #[tokio::test] - async fn fails_when_amount_above_max() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - // max_payment_size_msat is 100_000_000 - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(200_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::AmountAboveMaximum { .. } - } - )); - } - - #[tokio::test] - async fn fails_when_amount_doesnt_cover_fee_plus_minimum() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - // min_fee = 2_000, htlc_minimum = 1_000 - // Amount must be > fee + htlc_minimum - // At 3_000: fee ~= 2_000 + (3_000 * 10_000 / 1_000_000) = 2_030 - // 2_030 + 1_000 = 3_030 > 3_000, so should fail - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(3_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::InsufficientForFee { .. 
} - } - )); - } - - #[tokio::test] - async fn fails_when_fee_computation_overflows() { - let mut entry = test_datastore_entry(None); - entry.opening_fee_params.min_fee_msat = Msat(u64::MAX / 2); - entry.opening_fee_params.proportional = Ppm(u32::MAX); - entry.opening_fee_params.min_payment_size_msat = Msat(1); - entry.opening_fee_params.max_payment_size_msat = Msat(u64::MAX); - - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(u64::MAX / 2, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::FeeOverflow, - } - )); - } - - #[tokio::test] - async fn fails_when_channel_capacity_errors() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity_error(); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.expect_err("should fail"); - - assert!(matches!(result, HtlcError::CapacityQuery(_))); - } - - #[tokio::test] - async fn fails_when_policy_denies_channel() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry).with_channel_denied(); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::PolicyDenied, - } - )); - } - - #[tokio::test] - async fn fails_when_fund_channel_errors() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_error(); - let h = handler(api); - - let onion = test_onion(test_scid(), 
TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.expect_err("should fail"); - - assert!(matches!(result, HtlcError::FundChannel(_))); - } - - #[tokio::test] - async fn success_flow_continues_with_modified_payload() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api.clone()); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - let HtlcDecision::Forward { - payload, - forward_to, - extra_tlvs, - } = result - else { - panic!("expected forward, got {:?}", result) - }; - - assert_eq!(forward_to, test_channel_id()); - assert!(!payload.0.is_empty()); - assert!(!extra_tlvs.0.is_empty()); - } - - #[tokio::test] - async fn polls_until_channel_ready() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(false); - - let h = handler(api.clone()); - - // Spawn handler, will block on channel ready - let handle = tokio::spawn(async move { - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - result - }); - - // Let it poll a few times - tokio::time::sleep(Duration::from_millis(10)).await; - assert!(api.channel_ready_check_count() > 1); - - // Now make channel ready - *api.channel_ready.lock().unwrap() = true; - - let result = handle.await.unwrap(); - assert!(matches!(result, HtlcDecision::Forward { .. 
})); - } - - #[tokio::test] - async fn deducts_fee_from_forward_amount() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - let HtlcDecision::Forward { payload, .. } = result else { - panic!("expected forward, got {:?}", result) - }; - - // Verify payload contains deducted amount - // fee = max(min_fee, amount * proportional / 1_000_000) - // fee = max(2_000, 10_000_000 * 10_000 / 1_000_000) = max(2_000, 100_000) = 100_000 - // deducted = 10_000_000 - 100_000 = 9_900_000 - let forward_amt = payload.get_tu64(TLV_FORWARD_AMT).unwrap(); - assert_eq!(forward_amt, Some(9_900_000)); - } - - #[tokio::test] - async fn extra_tlvs_contain_opening_fee() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - let HtlcDecision::Forward { extra_tlvs, .. 
} = result else { - panic!("expected forward, got {:?}", result) - }; - - // Opening fee should be in TLV 65537 - let opening_fee = extra_tlvs.get_u64(65537).unwrap(); - assert_eq!(opening_fee, Some(100_000)); // Same fee calculation as above - } - - #[tokio::test] - async fn handles_minimum_valid_amount() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - // Just enough to cover fee + htlc_minimum - // fee at 1_000_000 = max(2_000, 1_000_000 * 10_000 / 1_000_000) = max(2_000, 10_000) = 10_000 - // Need: fee + htlc_minimum < amount - // 10_000 + 1_000 = 11_000 < 1_000_000 ✓ - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(1_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!(result, HtlcDecision::Forward { .. })); - } - - #[tokio::test] - async fn handles_maximum_valid_amount() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(200_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - // max_payment_size_msat is 100_000_000 - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(100_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!(result, HtlcDecision::Forward { .. 
})); - } -} diff --git a/plugins/lsps-plugin/src/core/lsps2/manager.rs b/plugins/lsps-plugin/src/core/lsps2/manager.rs new file mode 100644 index 000000000000..58aa64d00fe1 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/manager.rs @@ -0,0 +1,990 @@ +use super::actor::{ActionExecutor, ActorInboxHandle, HtlcResponse}; +use super::provider::{DatastoreProvider, ForwardActivity, RecoveryProvider}; +use super::session::{PaymentPart, Session}; +use crate::core::lsps2::actor::SessionActor; +use crate::core::lsps2::event_sink::EventSink; +use crate::proto::lsps0::ShortChannelId; +use crate::proto::lsps2::{DatastoreEntry, SessionOutcome}; +pub use bitcoin::hashes::sha256::Hash as PaymentHash; +use chrono::Utc; +use log::{debug, warn}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Debug, thiserror::Error)] +pub enum ManagerError { + #[error("session terminated")] + SessionTerminated, + #[error("datastore lookup failed: {0}")] + DatastoreLookup(#[source] anyhow::Error), +} + +pub struct SessionConfig { + pub max_parts: usize, + pub collect_timeout_secs: u64, +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + max_parts: 30, // Core-Lightning default. + collect_timeout_secs: 90, // Blip52 default. + } + } +} + +pub struct SessionManager { + sessions: Mutex>, + datastore: Arc, + executor: Arc, + config: SessionConfig, + event_sink: Arc, +} + +impl + SessionManager +{ + pub fn new(datastore: Arc, executor: Arc, config: SessionConfig, event_sink: Arc) -> Self { + Self { + sessions: Mutex::new(HashMap::new()), + datastore, + executor, + config, + event_sink, + } + } + + pub async fn recover(&self, recovery: Arc) -> anyhow::Result<()> { + let entries = self.datastore.list_active_sessions().await?; + + for (scid, entry) in entries { + let payment_hash = entry.payment_hash.as_deref().and_then(|s| s.parse::().ok()); + if let Some(handle) = self.recover_session(scid, entry, &recovery).await? 
{ + if let Some(hash) = payment_hash { + self.sessions.lock().await.insert(hash, handle); + } else { + warn!("recovered session for scid={scid} has no payment_hash, dropping handle"); + } + } + } + + Ok(()) + } + + async fn recover_session( + &self, + scid: ShortChannelId, + entry: DatastoreEntry, + recovery: &Arc, + ) -> anyhow::Result> { + let (channel_id, funding_psbt) = match (&entry.channel_id, &entry.funding_psbt) { + (None, _) => { + if entry.opening_fee_params.valid_until < Utc::now() { + self.datastore + .finalize_session(&scid, SessionOutcome::Timeout) + .await?; + } + return Ok(None); + } + (Some(cid), Some(psbt)) => (cid.clone(), psbt.clone()), + _ => { + warn!("inconsistent datastore entry for scid={scid}, finalizing as Failed"); + self.datastore + .finalize_session(&scid, SessionOutcome::Failed) + .await?; + return Ok(None); + } + }; + + let info = recovery.get_channel_recovery_info(&channel_id).await?; + if !info.exists { + self.datastore + .finalize_session(&scid, SessionOutcome::Abandoned) + .await?; + return Ok(None); + } + + let activity = recovery.get_forward_activity(&channel_id).await?; + + match activity { + ForwardActivity::NoForwards => { + recovery + .close_and_unreserve(&channel_id, &funding_psbt) + .await?; + let mut entry = entry; + entry.channel_id = None; + entry.funding_psbt = None; + entry.funding_txid = None; + self.datastore.save_session(&scid, &entry).await?; + Ok(None) + } + ForwardActivity::AllFailed => { + self.datastore + .finalize_session(&scid, SessionOutcome::Abandoned) + .await?; + Ok(None) + } + ForwardActivity::Offered => { + let (session, initial_actions) = Session::recover( + channel_id.clone(), + funding_psbt.clone(), + None, + entry.opening_fee_params.clone(), + ); + + let handle = SessionActor::spawn_recovered_session_actor( + session, + entry, + initial_actions, + self.executor.clone(), + scid, + self.datastore.clone(), + self.event_sink.clone(), + ); + + Ok(Some(handle)) + } + ForwardActivity::Settled => { + // 
Forwards already settled — recover into Broadcasting state + // so the actor self-drives via BroadcastFundingTx without + // needing a forward_event notification from CLN. + let preimage = entry.preimage.clone().unwrap_or_default(); + let (session, initial_actions) = Session::recover( + channel_id.clone(), + funding_psbt.clone(), + Some(preimage), + entry.opening_fee_params.clone(), + ); + + let handle = SessionActor::spawn_recovered_session_actor( + session, + entry, + initial_actions, + self.executor.clone(), + scid, + self.datastore.clone(), + self.event_sink.clone(), + ); + + Ok(Some(handle)) + } + } + } + + pub async fn on_part( + &self, + payment_hash: PaymentHash, + scid: ShortChannelId, + part: PaymentPart, + ) -> Result { + let handle = { + let mut sessions = self.sessions.lock().await; + if let Some(handle) = sessions.get(&payment_hash) { + handle.clone() + } else { + let handle = self.create_session(&scid, &payment_hash).await?; + sessions.insert(payment_hash, handle.clone()); + handle + } + }; + + match handle.add_part(part).await { + Ok(resp) => Ok(resp), + Err(_) => { + self.sessions.lock().await.remove(&payment_hash); + Err(ManagerError::SessionTerminated) + } + } + } + + pub async fn on_payment_settled( + &self, + payment_hash: PaymentHash, + preimage: Option, + updated_index: Option, + ) -> Result<(), ManagerError> { + let handle = { + let mut sessions = self.sessions.lock().await; + match sessions.remove(&payment_hash) { + Some(handle) => handle, + None => { + debug!("on_payment_settled: no session for {payment_hash}"); + return Ok(()); + } + } + }; + + match handle.payment_settled(preimage, updated_index).await { + Ok(()) => Ok(()), + Err(_) => Err(ManagerError::SessionTerminated), + } + } + + pub async fn on_payment_failed( + &self, + payment_hash: PaymentHash, + updated_index: Option, + ) -> Result<(), ManagerError> { + let handle = { + let mut sessions = self.sessions.lock().await; + match sessions.remove(&payment_hash) { + Some(handle) => 
handle, + None => { + debug!("on_payment_failed: no session for {payment_hash}"); + return Ok(()); + } + } + }; + + match handle.payment_failed(updated_index).await { + Ok(()) => Ok(()), + Err(_) => Err(ManagerError::SessionTerminated), + } + } + + pub async fn on_new_block(&self, height: u32) { + let handles: Vec<(PaymentHash, ActorInboxHandle)> = { + let sessions = self.sessions.lock().await; + sessions.iter().map(|(k, v)| (*k, v.clone())).collect() + }; + + let mut dead = Vec::new(); + for (hash, handle) in handles { + if handle.new_block(height).await.is_err() { + dead.push(hash); + } + } + + if !dead.is_empty() { + let mut sessions = self.sessions.lock().await; + for hash in dead { + sessions.remove(&hash); + } + } + } + + async fn create_session( + &self, + scid: &ShortChannelId, + payment_hash: &PaymentHash, + ) -> Result { + let mut entry = self + .datastore + .get_buy_request(scid) + .await + .map_err(ManagerError::DatastoreLookup)?; + + entry.payment_hash = Some(payment_hash.to_string()); + self.datastore + .save_session(scid, &entry) + .await + .map_err(ManagerError::DatastoreLookup)?; + + let peer_id = entry.peer_id.to_string(); + let session = Session::new( + self.config.max_parts, + entry.opening_fee_params.clone(), + entry.expected_payment_size, + entry.channel_capacity_msat, + peer_id.clone(), + ); + + Ok(SessionActor::spawn_session_actor( + session, + entry, + self.executor.clone(), + peer_id, + self.config.collect_timeout_secs, + *scid, + self.datastore.clone(), + self.event_sink.clone(), + )) + } + + #[cfg(test)] + async fn session_count(&self) -> usize { + self.sessions.lock().await.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::lsps2::event_sink::NoopEventSink; + use crate::core::lsps2::provider::{ChannelRecoveryInfo, ForwardActivity, RecoveryProvider}; + use crate::proto::lsps0::{Msat, Ppm}; + use crate::proto::lsps2::{DatastoreEntry, OpeningFeeParams, Promise, SessionOutcome}; + use async_trait::async_trait; + 
use bitcoin::hashes::Hash; + use chrono::{Duration as ChronoDuration, Utc}; + use std::time::Duration; + + fn test_payment_hash(byte: u8) -> PaymentHash { + PaymentHash::from_byte_array([byte; 32]) + } + + fn test_scid() -> ShortChannelId { + ShortChannelId::from(100u64 << 40 | 1u64 << 16) + } + + fn test_scid_2() -> ShortChannelId { + ShortChannelId::from(200u64 << 40 | 2u64 << 16) + } + + fn unknown_scid() -> ShortChannelId { + ShortChannelId::from(999u64 << 40 | 9u64 << 16 | 9) + } + + fn test_peer_id() -> bitcoin::secp256k1::PublicKey { + "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798" + .parse() + .unwrap() + } + + fn opening_fee_params(min_fee_msat: u64) -> OpeningFeeParams { + OpeningFeeParams { + min_fee_msat: Msat::from_msat(min_fee_msat), + proportional: Ppm::from_ppm(1_000), + valid_until: Utc::now() + ChronoDuration::hours(1), + min_lifetime: 144, + max_client_to_self_delay: 2016, + min_payment_size_msat: Msat::from_msat(1), + max_payment_size_msat: Msat::from_msat(u64::MAX), + promise: Promise("test-promise".to_owned()), + } + } + + fn test_datastore_entry() -> DatastoreEntry { + DatastoreEntry { + peer_id: test_peer_id(), + opening_fee_params: opening_fee_params(1), + expected_payment_size: Some(Msat::from_msat(1_000)), + channel_capacity_msat: Msat::from_msat(100_000_000), + created_at: Utc::now(), + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, + forwards_updated_index: None, + payment_hash: None, + } + } + + fn part(htlc_id: u64, amount_msat: u64) -> PaymentPart { + PaymentPart { + htlc_id, + amount_msat: Msat::from_msat(amount_msat), + cltv_expiry: 100, + } + } + + struct MockDatastore { + entries: HashMap, + } + + impl MockDatastore { + fn new() -> Self { + let mut entries = HashMap::new(); + entries.insert(test_scid().to_string(), test_datastore_entry()); + entries.insert(test_scid_2().to_string(), test_datastore_entry()); + Self { entries } + } + } + + #[async_trait] + impl 
DatastoreProvider for MockDatastore { + async fn store_buy_request( + &self, + scid: &ShortChannelId, + _peer_id: &bitcoin::secp256k1::PublicKey, + _offer: &OpeningFeeParams, + _expected_payment_size: &Option, + _channel_capacity_msat: &Msat, + ) -> anyhow::Result { + self.get_buy_request(scid).await + } + + async fn get_buy_request(&self, scid: &ShortChannelId) -> anyhow::Result { + self.entries + .get(&scid.to_string()) + .cloned() + .ok_or_else(|| anyhow::anyhow!("not found: {scid}")) + } + + async fn save_session( + &self, + _scid: &ShortChannelId, + _entry: &DatastoreEntry, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn finalize_session( + &self, + _scid: &ShortChannelId, + _outcome: SessionOutcome, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn list_active_sessions(&self) -> anyhow::Result> { + Ok(self.entries.iter().map(|(k, v)| { + (k.parse::().unwrap(), v.clone()) + }).collect()) + } + } + + struct MockExecutor { + fund_succeeds: bool, + } + + #[async_trait] + impl ActionExecutor for MockExecutor { + async fn fund_channel( + &self, + _peer_id: String, + _channel_capacity_msat: Msat, + _opening_fee_params: OpeningFeeParams, + _scid: ShortChannelId, + ) -> anyhow::Result<(String, String)> { + if self.fund_succeeds { + Ok(("channel-id-1".to_string(), "psbt-1".to_string())) + } else { + Err(anyhow::anyhow!("fund error")) + } + } + + async fn broadcast_tx( + &self, + _channel_id: String, + _funding_psbt: String, + ) -> anyhow::Result { + Ok("mock-txid".to_string()) + } + + async fn abandon_session( + &self, + _channel_id: String, + _funding_psbt: String, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn disconnect(&self, _peer_id: String) -> anyhow::Result<()> { + Ok(()) + } + + async fn is_channel_alive(&self, _channel_id: &str) -> anyhow::Result { + Ok(true) + } + } + + fn test_manager(fund_succeeds: bool) -> Arc> { + Arc::new(SessionManager::new( + Arc::new(MockDatastore::new()), + Arc::new(MockExecutor { fund_succeeds }), + SessionConfig { 
+ max_parts: 3, + ..SessionConfig::default() + }, + Arc::new(NoopEventSink), + )) + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn first_part_creates_session() { + let mgr = test_manager(true); + + let resp = mgr + .on_part(test_payment_hash(1), test_scid(), part(1, 1_000)) + .await + .unwrap(); + + assert!(matches!(resp, HtlcResponse::Forward { .. })); + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn second_part_routes_to_existing() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // First part reaches threshold (expected=1000) and gets Forward. + let resp1 = mgr + .on_part(hash, test_scid(), part(1, 1_000)) + .await + .unwrap(); + assert!(matches!(resp1, HtlcResponse::Forward { .. })); + + // Session is now in AwaitingSettlement. Second part is forwarded immediately. + let resp2 = mgr.on_part(hash, test_scid(), part(2, 500)).await.unwrap(); + match resp2 { + HtlcResponse::Forward { fee_msat, .. } => { + assert_eq!(fee_msat, 0, "late-arriving part should have zero fee"); + } + other => panic!("expected Forward, got {other:?}"), + } + + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn different_hashes_create_separate_sessions() { + let mgr = test_manager(true); + + let r1 = mgr + .on_part(test_payment_hash(1), test_scid(), part(1, 1_000)) + .await + .unwrap(); + let r2 = mgr + .on_part(test_payment_hash(2), test_scid_2(), part(2, 1_000)) + .await + .unwrap(); + + assert!(matches!(r1, HtlcResponse::Forward { .. })); + assert!(matches!(r2, HtlcResponse::Forward { .. 
})); + assert_eq!(mgr.session_count().await, 2); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn terminated_session_cleaned_up() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // First on_part with partial amount — won't reach threshold, blocks. + let mgr2 = mgr.clone(); + let h1 = tokio::spawn(async move { mgr2.on_part(hash, test_scid(), part(1, 500)).await }); + + // Advance past 90s collect timeout. + tokio::time::sleep(Duration::from_secs(91)).await; + + // First part should have received Fail from timeout. + let resp = h1.await.unwrap().unwrap(); + assert!(matches!(resp, HtlcResponse::Fail { .. })); + + // Stale entry still in the map. + assert_eq!(mgr.session_count().await, 1); + + // Next on_part detects dead session and cleans up. + let err = mgr + .on_part(hash, test_scid(), part(2, 500)) + .await + .unwrap_err(); + assert!(matches!(err, ManagerError::SessionTerminated { .. })); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn datastore_lookup_failure() { + let mgr = test_manager(true); + + let err = mgr + .on_part(test_payment_hash(1), unknown_scid(), part(1, 1_000)) + .await + .unwrap_err(); + + assert!(matches!(err, ManagerError::DatastoreLookup { .. })); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_settled_unknown_hash_is_ok() { + let mgr = test_manager(true); + let result = mgr.on_payment_settled(test_payment_hash(99), None, None).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_settled_active_session() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Create session and forward payment. + let resp = mgr + .on_part(hash, test_scid(), part(1, 1_000)) + .await + .unwrap(); + assert!(matches!(resp, HtlcResponse::Forward { .. 
})); + + // Settle payment — session is in AwaitingSettlement. + let result = mgr.on_payment_settled(hash, None, None).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_settled_stale_session_cleaned_up() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Create a session with a partial amount — won't reach threshold. + let mgr2 = mgr.clone(); + let h1 = tokio::spawn(async move { mgr2.on_part(hash, test_scid(), part(1, 500)).await }); + + // Advance past 90s collect timeout → actor dies. + tokio::time::sleep(Duration::from_secs(91)).await; + let resp = h1.await.unwrap().unwrap(); + assert!(matches!(resp, HtlcResponse::Fail { .. })); + + // Stale entry remains. + assert_eq!(mgr.session_count().await, 1); + + // on_payment_settled hits dead handle → removes entry. + let err = mgr.on_payment_settled(hash, None, None).await.unwrap_err(); + assert!(matches!(err, ManagerError::SessionTerminated { .. })); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_failed_unknown_hash_is_ok() { + let mgr = test_manager(true); + let result = mgr.on_payment_failed(test_payment_hash(99), None).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_failed_active_session() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Create session and forward payment. + let resp = mgr + .on_part(hash, test_scid(), part(1, 1_000)) + .await + .unwrap(); + assert!(matches!(resp, HtlcResponse::Forward { .. })); + + // Fail payment — session is in AwaitingSettlement. 
+ let result = mgr.on_payment_failed(hash, None).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn concurrent_first_parts_same_hash() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Two concurrent on_part calls for the same hash. + // expected_payment_size=1000, so two 500-msat parts reach threshold together. + let mgr2 = mgr.clone(); + let h1 = tokio::spawn(async move { mgr2.on_part(hash, test_scid(), part(1, 500)).await }); + let mgr3 = mgr.clone(); + let h2 = tokio::spawn(async move { mgr3.on_part(hash, test_scid(), part(2, 500)).await }); + + let r1 = h1.await.unwrap().unwrap(); + let r2 = h2.await.unwrap().unwrap(); + + assert!(matches!(r1, HtlcResponse::Forward { .. })); + assert!(matches!(r2, HtlcResponse::Forward { .. })); + assert_eq!(mgr.session_count().await, 1); + } + + struct MockRecoveryProvider { + channel_exists: bool, + forward_activity: ForwardActivity, + } + + impl Default for MockRecoveryProvider { + fn default() -> Self { + Self { + channel_exists: false, + forward_activity: ForwardActivity::NoForwards, + } + } + } + + #[async_trait] + impl RecoveryProvider for MockRecoveryProvider { + async fn get_forward_activity( + &self, + _channel_id: &str, + ) -> anyhow::Result { + Ok(self.forward_activity.clone()) + } + async fn get_channel_recovery_info( + &self, + _channel_id: &str, + ) -> anyhow::Result { + Ok(ChannelRecoveryInfo { + exists: self.channel_exists, + withheld: true, + }) + } + async fn close_and_unreserve( + &self, + _channel_id: &str, + _funding_psbt: &str, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + #[tokio::test] + async fn recover_pre_funding_expired_finalizes_as_timeout() { + let mut ds = MockDatastore::new(); + // Clear default entries, add one with expired opening_fee_params. 
+ ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.opening_fee_params.valid_until = Utc::now() - ChronoDuration::hours(1); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + mgr.recover(Arc::new(MockRecoveryProvider::default())) + .await + .unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_pre_funding_valid_leaves_session_for_replay() { + let ds = MockDatastore::new(); // entries have valid_until in future + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + mgr.recover(Arc::new(MockRecoveryProvider::default())).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + + // Replayed HTLC should still create a fresh session + let _response = mgr.on_part( + test_payment_hash(1), + test_scid(), + part(1, 1_000), + ).await.unwrap(); + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test] + async fn recover_funded_channel_gone_finalizes_abandoned() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-gone".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: false, + forward_activity: ForwardActivity::NoForwards, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_no_forwards_resets_session() { + let mut ds = 
MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::NoForwards, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_all_failed_finalizes_abandoned() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::AllFailed, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_offered_registers_in_sessions() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: 
true, + forward_activity: ForwardActivity::Offered, + }); + + mgr.recover(recovery).await.unwrap(); + + // Recovered session must be reachable via on_payment_settled. + assert_eq!(mgr.session_count().await, 1); + let result = mgr.on_payment_settled(test_payment_hash(1), None, None).await; + assert!(result.is_ok()); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_offered_reachable_by_on_payment_failed() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Offered, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 1); + + let result = mgr.on_payment_failed(test_payment_hash(1), None).await; + assert!(result.is_ok()); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_no_payment_hash_not_registered() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = None; // No payment_hash + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Offered, + }); + + 
mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_settled_registers_in_sessions() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Settled, + }); + + mgr.recover(recovery).await.unwrap(); + + // Settled sessions should still be registered (actor will receive + // BroadcastFundingTx as initial action and self-drive to completion). + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn recovered_actor_settles_via_inbox() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry.clone()); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Offered, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 1); + + // Simulate forward_event delivering settlement. 
+ let result = mgr.on_payment_settled(test_payment_hash(1), Some("preimage123".to_string()), Some(1)).await; + assert!(result.is_ok()); + assert_eq!(mgr.session_count().await, 0); + + // Give the actor time to finalize. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } +} diff --git a/plugins/lsps-plugin/src/core/lsps2/mod.rs b/plugins/lsps-plugin/src/core/lsps2/mod.rs index 18bf1cb51ce1..eadfdadc890d 100644 --- a/plugins/lsps-plugin/src/core/lsps2/mod.rs +++ b/plugins/lsps-plugin/src/core/lsps2/mod.rs @@ -1,3 +1,6 @@ -pub mod htlc; +pub mod actor; +pub mod event_sink; +pub mod manager; pub mod provider; pub mod service; +pub mod session; diff --git a/plugins/lsps-plugin/src/core/lsps2/provider.rs b/plugins/lsps-plugin/src/core/lsps2/provider.rs index 6466630a4748..45be1bfef1ba 100644 --- a/plugins/lsps-plugin/src/core/lsps2/provider.rs +++ b/plugins/lsps-plugin/src/core/lsps2/provider.rs @@ -1,24 +1,15 @@ use anyhow::Result; use async_trait::async_trait; -use bitcoin::hashes::sha256::Hash; use bitcoin::secp256k1::PublicKey; use crate::proto::{ lsps0::{Msat, ShortChannelId}, lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, + DatastoreEntry, Lsps2PolicyBuyRequest, Lsps2PolicyBuyResponse, Lsps2PolicyGetInfoRequest, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, SessionOutcome, }, }; -pub type Blockheight = u32; - -#[async_trait] -pub trait BlockheightProvider: Send + Sync { - async fn get_blockheight(&self) -> Result; -} - #[async_trait] pub trait DatastoreProvider: Send + Sync { async fn store_buy_request( @@ -27,27 +18,61 @@ pub trait DatastoreProvider: Send + Sync { peer_id: &PublicKey, offer: &OpeningFeeParams, expected_payment_size: &Option, - ) -> Result; + channel_capacity_msat: &Msat, + ) -> Result; async fn get_buy_request(&self, scid: &ShortChannelId) -> Result; - async fn del_buy_request(&self, scid: 
&ShortChannelId) -> Result<()>; + + async fn save_session(&self, scid: &ShortChannelId, entry: &DatastoreEntry) -> Result<()>; + + async fn finalize_session(&self, scid: &ShortChannelId, outcome: SessionOutcome) -> Result<()>; + + async fn list_active_sessions(&self) -> Result>; } +/// Status of forwards on a channel, used during recovery classification. +#[derive(Debug, Clone, PartialEq)] +pub enum ForwardActivity { + /// No forwards ever happened on this channel. + NoForwards, + /// All forwards failed (none settled or offered). + AllFailed, + /// Some forwards are in-flight (OFFERED) but none have settled yet. + Offered, + /// At least one forward has settled. + Settled, +} + +/// Information about a channel needed for recovery classification. +#[derive(Debug, Clone)] +pub struct ChannelRecoveryInfo { + pub exists: bool, + pub withheld: bool, +} + +/// Provides recovery-specific queries. Separated from ActionExecutor +/// to keep the normal operation interface clean. #[async_trait] -pub trait LightningProvider: Send + Sync { - async fn fund_jit_channel(&self, peer_id: &PublicKey, amount: &Msat) -> Result<(Hash, String)>; - async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Hash) -> Result; +pub trait RecoveryProvider: Send + Sync { + /// Check forward activity on a channel using both in-flight HTLCs + /// and historical forwards. + async fn get_forward_activity(&self, channel_id: &str) -> Result; + + /// Get channel recovery info (exists, withheld status). + async fn get_channel_recovery_info(&self, channel_id: &str) -> Result; + + /// Close a channel and unreserve its inputs. 
+ async fn close_and_unreserve(&self, channel_id: &str, funding_psbt: &str) -> Result<()>; } #[async_trait] -pub trait Lsps2OfferProvider: Send + Sync { - async fn get_offer( +pub trait Lsps2PolicyProvider: Send + Sync { + async fn get_blockheight(&self) -> Result; + + async fn get_info( &self, request: &Lsps2PolicyGetInfoRequest, ) -> Result; - async fn get_channel_capacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> Result; + async fn buy(&self, request: &Lsps2PolicyBuyRequest) -> Result; } diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs b/plugins/lsps-plugin/src/core/lsps2/service.rs index a3ab32406e71..8b980a35b1a7 100644 --- a/plugins/lsps-plugin/src/core/lsps2/service.rs +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -1,6 +1,6 @@ use crate::{ core::{ - lsps2::provider::{BlockheightProvider, DatastoreProvider, Lsps2OfferProvider}, + lsps2::provider::{DatastoreProvider, Lsps2PolicyProvider}, router::JsonRpcRouterBuilder, server::LspsProtocol, }, @@ -9,7 +9,8 @@ use crate::{ lsps0::{LSPS0RpcErrorExt as _, ShortChannelId}, lsps2::{ Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, Lsps2GetInfoResponse, - Lsps2PolicyGetInfoRequest, OpeningFeeParams, ShortChannelIdJITExt, + Lsps2PolicyBuyRequest, Lsps2PolicyGetInfoRequest, OpeningFeeParams, + ShortChannelIdJITExt, }, }, register_handler, @@ -48,31 +49,39 @@ where } } -pub struct Lsps2ServiceHandler { - pub api: Arc, +pub struct Lsps2ServiceHandler { + pub datastore: Arc, + pub policy: Arc

, pub promise_secret: [u8; 32], } -impl Lsps2ServiceHandler { - pub fn new(api: Arc, promise_seret: &[u8; 32]) -> Self { +impl Lsps2ServiceHandler { + pub fn new( + datastore: Arc, + policy: Arc

, + promise_secret: &[u8; 32], + ) -> Self { Lsps2ServiceHandler { - api, - promise_secret: promise_seret.to_owned(), + datastore, + policy, + promise_secret: promise_secret.to_owned(), } } } #[async_trait] -impl Lsps2Handler - for Lsps2ServiceHandler +impl Lsps2Handler for Lsps2ServiceHandler +where + D: DatastoreProvider + 'static, + P: Lsps2PolicyProvider + 'static, { async fn handle_get_info( &self, request: Lsps2GetInfoRequest, ) -> std::result::Result { let res_data = self - .api - .get_offer(&Lsps2PolicyGetInfoRequest { + .policy + .get_info(&Lsps2PolicyGetInfoRequest { token: request.token.clone(), }) .await @@ -107,7 +116,7 @@ impl // Generate a tmp scid to identify jit channel request in htlc. let blockheight = self - .api + .policy .get_blockheight() .await .map_err(|_| RpcError::internal_error("internal error"))?; @@ -116,15 +125,30 @@ impl // already handed out -> Check datastore entries. let jit_scid = ShortChannelId::generate_jit(blockheight, 12); // Approximately 2 hours in the future. 
- let ok = self - .api - .store_buy_request(&jit_scid, &peer_id, &fee_params, &request.payment_size_msat) + let ch_cap_res = self + .policy + .buy(&Lsps2PolicyBuyRequest { + opening_fee_params: fee_params.clone(), + payment_size_msat: request.payment_size_msat, + }) .await .map_err(|_| RpcError::internal_error("internal error"))?; - if !ok { - return Err(RpcError::internal_error("internal error"))?; - } + let channel_capacity_msat = ch_cap_res + .channel_capacity_msat + .ok_or_else(|| RpcError::internal_error("channel capacity denied by policy"))?; + + let _entry = self + .datastore + .store_buy_request( + &jit_scid, + &peer_id, + &fee_params, + &request.payment_size_msat, + &channel_capacity_msat, + ) + .await + .map_err(|_| RpcError::internal_error("internal error"))?; Ok(Lsps2BuyResponse { jit_channel_scid: jit_scid, @@ -142,9 +166,9 @@ mod tests { use super::*; use crate::proto::lsps0::{Msat, Ppm}; use crate::proto::lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoResponse, OpeningFeeParams, - PolicyOpeningFeeParams, Promise, + DatastoreEntry, Lsps2PolicyBuyRequest, Lsps2PolicyBuyResponse, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, PolicyOpeningFeeParams, Promise, + SessionOutcome, }; use anyhow::{anyhow, Result as AnyResult}; use chrono::{TimeZone, Utc}; @@ -188,11 +212,13 @@ mod tests { offer_response: Arc>>, blockheight: Arc>>, store_result: Arc>>, + buy_response: Arc>>>, // Errors offer_error: Arc>, blockheight_error: Arc>, store_error: Arc>, + buy_error: Arc>, // Capture calls stored_requests: Arc>>, @@ -254,14 +280,34 @@ mod tests { self } + fn with_buy_capacity(self, capacity_msat: u64) -> Self { + *self.buy_response.lock().unwrap() = Some(Some(Msat::from_msat(capacity_msat))); + self + } + + fn with_buy_error(self) -> Self { + *self.buy_error.lock().unwrap() = true; + self + } + fn stored_requests(&self) -> Vec { self.stored_requests.lock().unwrap().clone() } } 
#[async_trait] - impl Lsps2OfferProvider for MockApi { - async fn get_offer( + impl Lsps2PolicyProvider for MockApi { + async fn get_blockheight(&self) -> AnyResult { + if *self.blockheight_error.lock().unwrap() { + return Err(anyhow!("blockheight error")); + } + self.blockheight + .lock() + .unwrap() + .ok_or_else(|| anyhow!("no blockheight set")) + } + + async fn get_info( &self, _request: &Lsps2PolicyGetInfoRequest, ) -> AnyResult { @@ -275,24 +321,21 @@ mod tests { .ok_or_else(|| anyhow!("no offer response set")) } - async fn get_channel_capacity( + async fn buy( &self, - _params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - unimplemented!("not needed for service tests") - } - } - - #[async_trait] - impl BlockheightProvider for MockApi { - async fn get_blockheight(&self) -> AnyResult { - if *self.blockheight_error.lock().unwrap() { - return Err(anyhow!("blockheight error")); + _request: &Lsps2PolicyBuyRequest, + ) -> AnyResult { + if *self.buy_error.lock().unwrap() { + return Err(anyhow!("buy error")); } - self.blockheight + let cap = self + .buy_response .lock() .unwrap() - .ok_or_else(|| anyhow!("no blockheight set")) + .ok_or_else(|| anyhow!("no buy response set"))?; + Ok(Lsps2PolicyBuyResponse { + channel_capacity_msat: cap, + }) } } @@ -304,7 +347,8 @@ mod tests { peer_id: &PublicKey, _fee_params: &OpeningFeeParams, payment_size: &Option, - ) -> AnyResult { + _channel_capacity_msat: &Msat, + ) -> AnyResult { if *self.store_error.lock().unwrap() { return Err(anyhow!("store error")); } @@ -314,20 +358,62 @@ mod tests { payment_size: *payment_size, }); - Ok(self.store_result.lock().unwrap().unwrap_or(true)) + if !self.store_result.lock().unwrap().unwrap_or(true) { + return Err(anyhow!("duplicate SCID")); + } + + Ok(DatastoreEntry { + peer_id: *peer_id, + opening_fee_params: OpeningFeeParams { + min_fee_msat: Msat(0), + proportional: Ppm(0), + valid_until: Utc::now(), + min_lifetime: 0, + max_client_to_self_delay: 0, + min_payment_size_msat: 
Msat(0), + max_payment_size_msat: Msat(0), + promise: Promise(String::new()), + }, + expected_payment_size: *payment_size, + channel_capacity_msat: Msat(0), + created_at: Utc::now(), + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, + forwards_updated_index: None, + payment_hash: None, + }) } async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { unimplemented!("not needed for service tests") } - async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { + async fn save_session( + &self, + _scid: &ShortChannelId, + _entry: &DatastoreEntry, + ) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } + + async fn finalize_session( + &self, + _scid: &ShortChannelId, + _outcome: SessionOutcome, + ) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } + + async fn list_active_sessions(&self) -> AnyResult> { unimplemented!("not needed for service tests") } } - fn handler(api: MockApi) -> Lsps2ServiceHandler { - Lsps2ServiceHandler::new(Arc::new(api), &test_secret()) + fn handler(api: MockApi) -> Lsps2ServiceHandler { + let api = Arc::new(api); + Lsps2ServiceHandler::new(api.clone(), api, &test_secret()) } #[tokio::test] @@ -408,7 +494,8 @@ mod tests { async fn buy_success_with_payment_size() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(true); + .with_store_result(true) + .with_buy_capacity(100_000_000); let h = handler(api.clone()); let request = Lsps2BuyRequest { @@ -434,7 +521,8 @@ mod tests { async fn buy_success_without_payment_size() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(true); + .with_store_result(true) + .with_buy_capacity(100_000_000); let h = handler(api.clone()); let request = Lsps2BuyRequest { @@ -537,7 +625,7 @@ mod tests { #[tokio::test] async fn buy_handles_blockheight_error() { - let api = MockApi::new().with_blockheight_error(); + let api = 
MockApi::new().with_blockheight_error().with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { @@ -553,7 +641,7 @@ mod tests { #[tokio::test] async fn buy_handles_store_error() { - let api = MockApi::new().with_blockheight(800_000).with_store_error(); + let api = MockApi::new().with_blockheight(800_000).with_store_error().with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { @@ -568,10 +656,11 @@ mod tests { } #[tokio::test] - async fn buy_handles_store_returns_false() { + async fn buy_handles_store_duplicate_error() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(false); + .with_store_result(false) + .with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { @@ -589,7 +678,8 @@ mod tests { async fn buy_generates_unique_scids() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(true); + .with_store_result(true) + .with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { diff --git a/plugins/lsps-plugin/src/core/lsps2/session.rs b/plugins/lsps-plugin/src/core/lsps2/session.rs new file mode 100644 index 000000000000..ea191a9f98d7 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/session.rs @@ -0,0 +1,2094 @@ +//! 
LSPS2 Service FSM
+
+use crate::proto::{
+    lsps0::Msat,
+    lsps2::{
+        compute_opening_fee,
+        failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER},
+        OpeningFeeParams, SessionOutcome,
+    },
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
+pub enum Error {
+    #[error("opening fee computation overflow")]
+    FeeOverflow,
+    #[error("invalid state transition")]
+    InvalidTransition {
+        state: SessionState,
+        input: SessionInput,
+    },
+    #[error(
+        "opening fee {opening_fee_msat} exceeds deductible capacity {deductible_capacity_msat}"
+    )]
+    InsufficientDeductibleCapacity {
+        opening_fee_msat: u64,
+        deductible_capacity_msat: u128,
+    },
+}
+
+type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct PaymentPart {
+    pub htlc_id: u64,
+    pub amount_msat: Msat,
+    pub cltv_expiry: u32,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ForwardPart {
+    pub htlc_id: u64,
+    pub fee_msat: u64,
+    pub forward_msat: u64,
+}
+
+impl From<PaymentPart> for ForwardPart {
+    fn from(part: PaymentPart) -> Self {
+        Self {
+            htlc_id: part.htlc_id,
+            fee_msat: 0,
+            forward_msat: part.amount_msat.msat(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum SessionInput {
+    /// Htlc intercepted
+    AddPart { part: PaymentPart },
+    /// Timeout waiting for parts to arrive from blip052: defaults to 90s.
+    CollectTimeout,
+    /// Channel funding failed.
+    FundingFailed,
+    /// Zero-conf channel funded, withheld, and ready.
+    ChannelReady {
+        channel_id: String,
+        funding_psbt: String,
+    },
+    /// The initial payment was successful
+    PaymentSettled,
+    /// The initial payment failed
+    PaymentFailed,
+    /// Funding tx was broadcasted
+    FundingBroadcasted,
+    /// A new block has been mined.
+    NewBlock { height: u32 },
+    /// The JIT channel has been closed or is no longer in CHANNELD_NORMAL.
+ ChannelClosed { channel_id: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionAction { + FailHtlcs { + failure_code: &'static str, + }, + ForwardHtlcs { + parts: Vec, + channel_id: String, + }, + FundChannel { + peer_id: String, + channel_capacity_msat: Msat, + opening_fee_params: OpeningFeeParams, + }, + FailSession, + AbandonSession { + channel_id: String, + funding_psbt: String, + }, + BroadcastFundingTx { + channel_id: String, + funding_psbt: String, + }, + Disconnect, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionEvent { + PaymentPartAdded { + part: PaymentPart, + n_parts: usize, + parts_sum: Msat, + }, + TooManyParts { + n_parts: usize, + }, + PaymentInsufficientForOpeningFee { + opening_fee_msat: u64, + n_parts: usize, + parts_sum: Msat, + }, + CollectTimeout { + n_parts: usize, + parts_sum: Msat, + }, + FundingChannel, + ForwardHtlcs { + channel_id: String, + n_parts: usize, + parts_sum: Msat, + opening_fee_msat: u64, + }, + PaymentSettled { + parts: Vec, + }, + PaymentFailed, + ChannelReady { + channel_id: String, + funding_psbt: String, + }, + FundingBroadcasted { + funding_psbt: String, + }, + SessionFailed, + SessionAbandoned, + SessionSucceeded, + UnsafeHtlcTimeout { + height: u32, + cltv_min: u32, + }, + UnusualInput { + state: String, + input: String, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionState { + Collecting { + parts: Vec, + }, + + /// Channel opened in progress, waiting for `channel_ready`. + AwaitingChannelReady { + parts: Vec, + opening_fee_msat: u64, + }, + + /// HTLCs forwarded, waiting for the client to settle or reject. + AwaitingSettlement { + forwarded_parts: Vec, + forwarded_amount_msat: u64, + deducted_fee_msat: u64, + channel_id: String, + funding_psbt: String, + }, + + /// HTLCs got resolved, broadcasting funding tx. + Broadcasting { + channel_id: String, + funding_psbt: String, + }, + + /// Terminal: session failed before a channel was opened. 
+ Failed, + + /// Terminal: session failed after a channel was opened. + Abandoned, + + /// Terminal: session successfully finished + Succeeded, +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct ApplyResult { + pub actions: Vec, + pub events: Vec, +} + +impl ApplyResult { + fn unusual_input(state: &SessionState, input: &SessionInput) -> Self { + Self { + events: vec![SessionEvent::UnusualInput { + state: format!("{:?}", state), + input: format!("{:?}", input), + }], + ..Default::default() + } + } +} + +fn cltv_min(parts: &[PaymentPart]) -> Option { + parts.iter().map(|p| p.cltv_expiry).min() +} + +#[derive(Debug)] +pub struct Session { + state: SessionState, + // from BOLT2 + max_parts: usize, + // From the offer/fee_policy + opening_fee_params: OpeningFeeParams, + payment_size_msat: Option, + channel_capacity_msat: Msat, + peer_id: String, +} + +impl Session { + pub fn new( + max_parts: usize, + opening_fee_params: OpeningFeeParams, + payment_size_msat: Option, + channel_capacity_msat: Msat, + peer_id: String, + ) -> Self { + Self { + state: SessionState::Collecting { parts: vec![] }, + max_parts, + opening_fee_params, + payment_size_msat, + channel_capacity_msat, + peer_id, + } + } + + /// Reconstruct a session from persisted state for crash recovery. + /// + /// Initializes the FSM in the appropriate state based on whether a + /// preimage was already captured: + /// - `preimage: None` → `AwaitingSettlement` (waiting for payment outcome) + /// - `preimage: Some` → `Broadcasting` (payment settled, need to broadcast) + /// + /// Forwarded HTLC parts are not reconstructed — CLN manages those + /// independently. The FSM only needs channel identity to drive + /// remaining actions. 
+ pub fn recover( + channel_id: String, + funding_psbt: String, + preimage: Option, + opening_fee_params: OpeningFeeParams, + ) -> (Self, Vec) { + let (state, actions) = if preimage.is_some() { + ( + SessionState::Broadcasting { + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }, + vec![SessionAction::BroadcastFundingTx { + channel_id, + funding_psbt, + }], + ) + } else { + ( + SessionState::AwaitingSettlement { + forwarded_parts: vec![], + forwarded_amount_msat: 0, + deducted_fee_msat: 0, + channel_id, + funding_psbt, + }, + vec![], + ) + }; + + let session = Self { + state, + max_parts: 0, + opening_fee_params, + payment_size_msat: None, + channel_capacity_msat: Msat::from_msat(0), + peer_id: String::new(), + }; + + (session, actions) + } + + pub fn is_terminal(&self) -> bool { + matches!( + self.state, + SessionState::Failed | SessionState::Abandoned | SessionState::Succeeded + ) + } + + pub fn outcome(&self) -> Option { + match &self.state { + SessionState::Succeeded => Some(SessionOutcome::Succeeded), + SessionState::Abandoned => Some(SessionOutcome::Abandoned), + SessionState::Failed => Some(SessionOutcome::Failed), + _ => None, + } + } + + fn check_cltv_timeout( + &mut self, + parts: &[PaymentPart], + height: u32, + ) -> Option { + let min = cltv_min(parts)?; + if height > min { + self.state = SessionState::Failed; + Some(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ], + events: vec![ + SessionEvent::UnsafeHtlcTimeout { + height, + cltv_min: min, + }, + SessionEvent::SessionFailed, + ], + }) + } else { + None + } + } + + pub fn apply(&mut self, input: SessionInput) -> Result { + match (&mut self.state, input) { + // + // Collecting transitions. 
+ // + (SessionState::Collecting { parts }, SessionInput::AddPart { part }) => { + parts.push(part.clone()); + let n_parts = parts.len(); + let parts_sum = parts.iter().map(|p| p.amount_msat).sum(); + + let mut events = vec![SessionEvent::PaymentPartAdded { + part: part.clone(), + n_parts, + parts_sum, + }]; + + // Variable-amount (None): first HTLC triggers immediately, second fails. + // Fixed-amount (Some): accumulate until threshold, fail if too many parts. + let threshold_reached = match self.payment_size_msat { + None => { + if n_parts > 1 { + self.state = SessionState::Failed; + events.push(SessionEvent::TooManyParts { n_parts }); + events.push(SessionEvent::SessionFailed); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::FailSession, + ], + events, + }); + } + true + } + Some(_) => { + if n_parts > self.max_parts { + self.state = SessionState::Failed; + events.push(SessionEvent::TooManyParts { n_parts }); + events.push(SessionEvent::SessionFailed); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::FailSession, + ], + events, + }); + } + parts_sum >= self.payment_size_msat.unwrap() + } + }; + + if threshold_reached { + let opening_fee_msat = compute_opening_fee( + parts_sum.msat(), + self.opening_fee_params.min_fee_msat.msat(), + self.opening_fee_params.proportional.ppm() as u64, + ) + .ok_or(Error::FeeOverflow)?; + + if opening_fee_msat >= parts_sum.msat() + || !is_deductible(parts, opening_fee_msat) + { + self.state = SessionState::Failed; + events.push(SessionEvent::PaymentInsufficientForOpeningFee { + opening_fee_msat, + n_parts, + parts_sum, + }); + events.push(SessionEvent::SessionFailed); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::FailSession, + ], + events, + }); + } + + // We collected enough parts to fund 
the channel, transition. + self.state = SessionState::AwaitingChannelReady { + parts: std::mem::take(parts), + opening_fee_msat, + }; + + events.push(SessionEvent::FundingChannel); + + return Ok(ApplyResult { + events, + actions: vec![SessionAction::FundChannel { + peer_id: self.peer_id.clone(), + channel_capacity_msat: self.channel_capacity_msat, + opening_fee_params: self.opening_fee_params.clone(), + }], + }); + } + + // Keep collecting + Ok(ApplyResult { + events, + ..Default::default() + }) + } + (SessionState::Collecting { parts }, SessionInput::CollectTimeout) => { + // Session collection timed out: we fail the session but keep + // the offer active. Next payment can create a new session. + let n_parts = parts.len(); + let parts_sum = parts.iter().map(|p| p.amount_msat).sum(); + + self.state = SessionState::Failed; + Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ], + events: vec![ + SessionEvent::CollectTimeout { n_parts, parts_sum }, + SessionEvent::SessionFailed, + ], + }) + } + (SessionState::Collecting { parts }, SessionInput::NewBlock { height }) => { + let parts = parts.clone(); + Ok(self.check_cltv_timeout(&parts, height).unwrap_or_default()) + } + ( + SessionState::Collecting { .. }, + ref input @ (SessionInput::ChannelReady { .. } + | SessionInput::PaymentSettled + | SessionInput::PaymentFailed + | SessionInput::FundingBroadcasted + | SessionInput::FundingFailed + | SessionInput::ChannelClosed { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // AwaitChannelReady transitions. + // + (SessionState::AwaitingChannelReady { parts, .. }, SessionInput::AddPart { part }) => { + parts.push(part.clone()); + let n_parts = parts.len(); + let parts_sum = parts.iter().map(|p| p.amount_msat).sum(); + + // We don't check for max parts here as we are in the middle of + // the channel funding. We'll check once we transitioned. 
+ + Ok(ApplyResult { + events: vec![SessionEvent::PaymentPartAdded { + part, + n_parts, + parts_sum, + }], + ..Default::default() + }) + } + ( + SessionState::AwaitingChannelReady { + parts, + opening_fee_msat, + }, + SessionInput::ChannelReady { + channel_id, + funding_psbt, + }, + ) => { + // We are transitioning in any case. + let parts = std::mem::take(parts); + let opening_fee_msat = std::mem::take(opening_fee_msat); + + let n_parts = parts.len(); + + let mut events = vec![SessionEvent::ChannelReady { + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }]; + + // Fail if we have too many parts. + if n_parts > self.max_parts { + self.state = SessionState::Abandoned; + events.push(SessionEvent::TooManyParts { n_parts }); + events.push(SessionEvent::SessionAbandoned); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + ], + events, + }); + } + + // Deduct opening_fee_msat. + let forwards = if let Ok(forwards) = allocate_forwards(&parts, opening_fee_msat) { + forwards + } else { + self.state = SessionState::Abandoned; + events.push(SessionEvent::SessionAbandoned); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + ], + events, + }); + }; + + let parts_sum = + Msat::from_msat(forwards.iter().map(|p| p.forward_msat + p.fee_msat).sum()); + + events.push(SessionEvent::ForwardHtlcs { + channel_id: channel_id.clone(), + n_parts, + parts_sum, + opening_fee_msat, + }); + + // Forward HTLCs and await settlement. 
+ self.state = SessionState::AwaitingSettlement { + forwarded_parts: forwards.clone(), + forwarded_amount_msat: forwards.iter().map(|p| p.forward_msat).sum(), + deducted_fee_msat: forwards.iter().map(|p| p.fee_msat).sum(), + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }; + + return Ok(ApplyResult { + actions: vec![SessionAction::ForwardHtlcs { + parts: forwards, + channel_id, + }], + events, + }); + } + ( + SessionState::AwaitingChannelReady { .. }, + ref input @ SessionInput::CollectTimeout, + ) => { + // Collection timeout is only relevant as long as we are still + // collecting parts to cover the fee. Once we opened the channel + // we don't care anymore. + Ok(ApplyResult::unusual_input(&self.state, input)) + } + (SessionState::AwaitingChannelReady { .. }, SessionInput::FundingFailed) => { + self.state = SessionState::Failed; + Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ], + events: vec![SessionEvent::SessionFailed], + }) + } + ( + SessionState::AwaitingChannelReady { parts, .. }, + SessionInput::NewBlock { height }, + ) => { + let parts = parts.clone(); + Ok(self.check_cltv_timeout(&parts, height).unwrap_or_default()) + } + ( + SessionState::AwaitingChannelReady { .. }, + ref input @ (SessionInput::PaymentSettled + | SessionInput::PaymentFailed + | SessionInput::FundingBroadcasted + | SessionInput::ChannelClosed { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // AwaitingSettlement transitions. + // + ( + SessionState::AwaitingSettlement { + forwarded_parts, + forwarded_amount_msat, + deducted_fee_msat, + channel_id, + .. + }, + SessionInput::AddPart { part }, + ) => { + // We forward late-arriving parts immediately in this state. 
+ let fp = ForwardPart { + htlc_id: part.htlc_id, + fee_msat: 0, + forward_msat: part.amount_msat.msat(), + }; + *forwarded_amount_msat += fp.forward_msat; + *deducted_fee_msat += fp.fee_msat; + forwarded_parts.push(fp.clone()); + + let n_parts = forwarded_parts.len(); + let parts_sum = Msat::from_msat(*forwarded_amount_msat + *deducted_fee_msat); + + // We don't check max_parts here as there is not much we can + // do about this at this stage, we definitely need a: + // TODO: Add integration test for #Htlcs > max_accepted_htlcs + + Ok(ApplyResult { + events: vec![ + SessionEvent::PaymentPartAdded { + part: part.clone(), + n_parts, + parts_sum, + }, + SessionEvent::ForwardHtlcs { + channel_id: channel_id.clone(), + n_parts: 1, + parts_sum: part.amount_msat, + opening_fee_msat: 0, + }, + ], + actions: vec![SessionAction::ForwardHtlcs { + parts: vec![fp], + channel_id: channel_id.clone(), + }], + }) + } + ( + SessionState::AwaitingSettlement { + forwarded_parts, + channel_id, + funding_psbt, + .. + }, + SessionInput::PaymentSettled, + ) => { + let channel_id = std::mem::take(channel_id); + let funding_psbt = std::mem::take(funding_psbt); + let parts = std::mem::take(forwarded_parts); + + self.state = SessionState::Broadcasting { + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }; + + Ok(ApplyResult { + actions: vec![SessionAction::BroadcastFundingTx { + channel_id, + funding_psbt, + }], + events: vec![SessionEvent::PaymentSettled { parts }], + }) + } + ( + SessionState::AwaitingSettlement { + channel_id, + funding_psbt, + .. + }, + SessionInput::PaymentFailed, + ) => { + let channel_id = std::mem::take(channel_id); + let funding_psbt = std::mem::take(funding_psbt); + + // Parts are already forwarded so we can't do anything here. + // Abandon session. 
+ + self.state = SessionState::Abandoned; + + Ok(ApplyResult { + actions: vec![ + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + SessionAction::Disconnect, + ], + events: vec![ + SessionEvent::PaymentFailed, + SessionEvent::SessionAbandoned, + ], + }) + } + ( + SessionState::AwaitingSettlement { + channel_id, + funding_psbt, + .. + }, + SessionInput::ChannelClosed { + channel_id: closed_id, + }, + ) if closed_id == *channel_id => { + let channel_id = std::mem::take(channel_id); + let funding_psbt = std::mem::take(funding_psbt); + + self.state = SessionState::Abandoned; + + Ok(ApplyResult { + actions: vec![ + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + SessionAction::Disconnect, + ], + events: vec![ + SessionEvent::PaymentFailed, + SessionEvent::SessionAbandoned, + ], + }) + } + ( + SessionState::AwaitingSettlement { .. }, + ref input @ (SessionInput::CollectTimeout + | SessionInput::ChannelReady { .. } + | SessionInput::FundingFailed + | SessionInput::FundingBroadcasted + | SessionInput::ChannelClosed { .. } + | SessionInput::NewBlock { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // Broadcasting transitions. + // + (SessionState::Broadcasting { channel_id, .. 
}, SessionInput::AddPart { part }) => { + // We already successfully settled htlcs for this payment + // hash, we don't care about max_parts anymore (for whatever + // reason we are collecting more of the same payment hash) + let n_parts = 1; + let parts_sum = part.amount_msat; + + Ok(ApplyResult { + actions: vec![SessionAction::ForwardHtlcs { + parts: vec![ForwardPart { + htlc_id: part.htlc_id, + fee_msat: 0, + forward_msat: part.amount_msat.msat(), + }], + channel_id: channel_id.clone(), + }], + events: vec![ + SessionEvent::PaymentPartAdded { + part: part.clone(), + n_parts, + parts_sum, + }, + SessionEvent::ForwardHtlcs { + channel_id: channel_id.clone(), + n_parts, + parts_sum, + opening_fee_msat: 0, + }, + ], + }) + } + (SessionState::Broadcasting { funding_psbt, .. }, SessionInput::FundingBroadcasted) => { + let funding_psbt = std::mem::take(funding_psbt); + + self.state = SessionState::Succeeded; + Ok(ApplyResult { + actions: vec![], + events: vec![ + SessionEvent::FundingBroadcasted { funding_psbt }, + SessionEvent::SessionSucceeded, + ], + }) + } + ( + SessionState::Broadcasting { .. }, + ref input @ (SessionInput::CollectTimeout + | SessionInput::ChannelReady { .. } + | SessionInput::PaymentSettled + | SessionInput::FundingFailed + | SessionInput::PaymentFailed + | SessionInput::ChannelClosed { .. } + | SessionInput::NewBlock { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // Terminal states. 
+ // + (SessionState::Failed | SessionState::Abandoned | SessionState::Succeeded, input) => { + return Err(Error::InvalidTransition { + state: self.state.clone(), + input, + }) + } + } + } +} + +fn max_deductible(parts: &[PaymentPart]) -> u128 { + parts + .iter() + .map(|p| u128::from(p.amount_msat.msat().saturating_sub(1))) + .sum() +} + +fn is_deductible(parts: &[PaymentPart], opening_fee_msat: u64) -> bool { + max_deductible(parts) >= u128::from(opening_fee_msat) +} + +fn allocate_forwards(parts: &[PaymentPart], opening_fee_msat: u64) -> Result> { + if !is_deductible(parts, opening_fee_msat) { + return Err(Error::InsufficientDeductibleCapacity { + opening_fee_msat, + deductible_capacity_msat: max_deductible(parts), + }); + } + + let mut remaining = opening_fee_msat; + let forwards: Vec = parts + .iter() + .map(|p| { + let amt = p.amount_msat.msat(); + let deduct = remaining.min(amt.saturating_sub(1)); + remaining -= deduct; + ForwardPart { + htlc_id: p.htlc_id, + fee_msat: deduct, + forward_msat: amt - deduct, + } + }) + .collect(); + + debug_assert_eq!(remaining, 0); + Ok(forwards) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::lsps0::Ppm; + use crate::proto::lsps2::Promise; + use chrono::{Duration, Utc}; + + fn part(htlc_id: u64, amount_msat: u64) -> PaymentPart { + PaymentPart { + htlc_id, + amount_msat: Msat::from_msat(amount_msat), + cltv_expiry: 100, + } + } + + fn part_with_cltv(htlc_id: u64, amount_msat: u64, cltv_expiry: u32) -> PaymentPart { + PaymentPart { + htlc_id, + amount_msat: Msat::from_msat(amount_msat), + cltv_expiry, + } + } + + fn opening_fee_params(min_fee_msat: u64, proportional_ppm: u32) -> OpeningFeeParams { + OpeningFeeParams { + min_fee_msat: Msat::from_msat(min_fee_msat), + proportional: Ppm::from_ppm(proportional_ppm), + valid_until: Utc::now() + Duration::hours(1), + min_lifetime: 144, + max_client_to_self_delay: 2016, + min_payment_size_msat: Msat::from_msat(1), + max_payment_size_msat: 
Msat::from_msat(u64::MAX), + promise: Promise("test-promise".to_owned()), + } + } + + fn session(max_parts: usize, payment_size_msat: Option, min_fee_msat: u64) -> Session { + Session { + state: SessionState::Collecting { parts: vec![] }, + max_parts, + opening_fee_params: opening_fee_params(min_fee_msat, 1_000), + payment_size_msat: payment_size_msat.map(Msat::from_msat), + channel_capacity_msat: Msat::from_msat(100_000_000), + peer_id: "peer-1".to_owned(), + } + } + + #[test] + fn collecting_add_part_emits_payment_part_added() { + let mut s = session(3, Some(2_000), 1); + let p = part(1, 1_000); + + let res = s.apply(SessionInput::AddPart { part: p.clone() }).unwrap(); + + assert!(res.actions.is_empty()); + assert_eq!( + res.events, + vec![SessionEvent::PaymentPartAdded { + part: p, + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }] + ); + } + + #[test] + fn collecting_below_expected_stays_collecting_no_actions() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + } + + #[test] + fn collecting_reaches_expected_transitions_and_funds_channel() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let res = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingChannelReady { .. 
})); + assert_eq!(res.actions.len(), 1); + match &res.actions[0] { + SessionAction::FundChannel { + peer_id, + channel_capacity_msat, + opening_fee_params, + } => { + assert_eq!(peer_id, "peer-1"); + assert_eq!(*channel_capacity_msat, Msat::from_msat(100_000_000)); + assert_eq!(opening_fee_params.min_fee_msat, Msat::from_msat(1)); + assert_eq!(opening_fee_params.proportional, Ppm::from_ppm(1_000)); + assert_eq!(opening_fee_params.min_payment_size_msat, Msat::from_msat(1)); + assert_eq!( + opening_fee_params.max_payment_size_msat, + Msat::from_msat(u64::MAX) + ); + assert_eq!( + opening_fee_params.promise, + Promise("test-promise".to_owned()) + ); + } + _ => panic!("expected FundChannel action"), + } + assert!(res.events.contains(&SessionEvent::FundingChannel)); + } + + #[test] + fn collecting_too_many_parts_emits_fail_action() { + let mut s = session(0, Some(1_000), 1); + + let res = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: part(1, 1_000), + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }, + SessionEvent::TooManyParts { n_parts: 1 }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER + }, + SessionAction::FailSession + ] + ); + } + + #[test] + fn collecting_insufficient_for_opening_fee_emits_fail_action() { + let mut s = session(3, Some(1_000), 1_000); + + let res = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: part(1, 1_000), + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }, + SessionEvent::PaymentInsufficientForOpeningFee { + opening_fee_msat: 1_000, + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: 
UNKNOWN_NEXT_PEER + }, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn collecting_collect_timeout_with_no_parts_fails_and_transitions_failed() { + let mut s = session(3, Some(2_000), 1); + + let res = s.apply(SessionInput::CollectTimeout).unwrap(); + + assert!(matches!(s.state, SessionState::Failed)); + assert_eq!( + res.events, + vec![ + SessionEvent::CollectTimeout { + n_parts: 0, + parts_sum: Msat::from_msat(0), + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn collecting_collect_timeout_with_parts_reports_count_and_sum() { + let mut s = session(3, Some(5_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 2_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + + let res = s.apply(SessionInput::CollectTimeout).unwrap(); + + assert!(matches!(s.state, SessionState::Failed)); + assert_eq!( + res.events, + vec![ + SessionEvent::CollectTimeout { + n_parts: 2, + parts_sum: Msat::from_msat(3_000), + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn failed_rejects_add_part_with_invalid_transition() { + let mut s = session(3, Some(2_000), 1); + s.state = SessionState::Failed; + + let err = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap_err(); + + assert_eq!( + err, + Error::InvalidTransition { + state: SessionState::Failed, + input: SessionInput::AddPart { + part: part(1, 1_000), + }, + } + ); + } + + #[test] + fn failed_rejects_collect_timeout_with_invalid_transition() { + let mut s = session(3, Some(2_000), 1); + s.state = SessionState::Failed; + + let err = s.apply(SessionInput::CollectTimeout).unwrap_err(); + assert_eq!( + 
err, + Error::InvalidTransition { + state: SessionState::Failed, + input: SessionInput::CollectTimeout, + } + ); + } + + #[test] + fn collecting_var_amount_single_htlc_triggers_funding() { + let mut s = session(3, None, 1); + let res = s + .apply(SessionInput::AddPart { + part: part(1, 10_000_000), + }) + .unwrap(); + + assert!(matches!( + s.state, + SessionState::AwaitingChannelReady { .. } + )); + assert!(res + .actions + .iter() + .any(|a| matches!(a, SessionAction::FundChannel { .. }))); + assert!(res + .events + .iter() + .any(|e| matches!(e, SessionEvent::FundingChannel))); + } + + #[test] + fn collecting_var_amount_second_htlc_fails() { + // Set up a session with one part already in Collecting + let mut s = session(3, None, 1); + s.state = SessionState::Collecting { + parts: vec![part(1, 5_000_000)], + }; + let res = s + .apply(SessionInput::AddPart { + part: part(2, 5_000_000), + }) + .unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert!(res + .events + .iter() + .any(|e| matches!(e, SessionEvent::TooManyParts { n_parts: 2 }))); + assert!(res + .actions + .iter() + .any(|a| matches!(a, SessionAction::FailHtlcs { .. }))); + } + + #[test] + fn collecting_var_amount_fee_computed_on_htlc_amount() { + let mut s = session(3, None, 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 10_000_000), + }) + .unwrap(); + + // fee = max(min_fee=1000, 10_000_000 * 1000 / 1_000_000) = max(1000, 10_000) = 10_000 + if let SessionState::AwaitingChannelReady { + opening_fee_msat, .. 
+ } = s.state + { + assert_eq!(opening_fee_msat, 10_000); + } else { + panic!("expected AwaitingChannelReady, got {:?}", s.state); + } + } + + #[test] + fn collecting_fee_overflow_returns_fee_overflow() { + let mut s = session(3, Some(u64::MAX), 1); + s.opening_fee_params.proportional = Ppm::from_ppm(u32::MAX); + + let err = s + .apply(SessionInput::AddPart { + part: part(1, u64::MAX), + }) + .unwrap_err(); + assert_eq!(err, Error::FeeOverflow); + } + + #[test] + fn collecting_unexpected_inputs_emit_unusual_input() { + let mut s = session(3, Some(2_000), 1); + + let res = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + + #[test] + fn channel_ready_forwards_all_parts_and_transitions_to_awaiting_settlement() { + let mut s = session(4, Some(2_000), 1); + + let p1 = part(1, 1_000); + let p2 = part(2, 1_000); + let p3 = part(3, 500); + + let _ = s.apply(SessionInput::AddPart { part: p1.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p2.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p3.clone() }).unwrap(); + + let res = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + assert_eq!( + s.state, + SessionState::AwaitingSettlement { + forwarded_parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ForwardPart { + htlc_id: p3.htlc_id, + fee_msat: 0, + forward_msat: 500, + }, + ], + forwarded_amount_msat: 2_498, + deducted_fee_msat: 2, + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + } + ); + + assert_eq!( + 
res.actions, + vec![SessionAction::ForwardHtlcs { + parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ForwardPart { + htlc_id: p3.htlc_id, + fee_msat: 0, + forward_msat: 500, + }, + ], + channel_id: "chan-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }, + SessionEvent::ForwardHtlcs { + channel_id: "chan-1".to_owned(), + n_parts: 3, + parts_sum: Msat::from_msat(2_500), + opening_fee_msat: 2, + }, + ] + ); + } + + #[test] + fn awaiting_settlement_add_part_forwards_single_part() { + let mut s = session(5, Some(2_000), 1); + + let p1 = part(1, 1_000); + let p2 = part(2, 1_000); + let p3 = part(3, 500); + + let _ = s.apply(SessionInput::AddPart { part: p1.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p2.clone() }).unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::AddPart { part: p3.clone() }).unwrap(); + + assert_eq!( + s.state, + SessionState::AwaitingSettlement { + forwarded_parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ForwardPart { + htlc_id: p3.htlc_id, + fee_msat: 0, + forward_msat: 500, + }, + ], + forwarded_amount_msat: 2_498, + deducted_fee_msat: 2, + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + } + ); + + assert_eq!( + res.actions, + vec![SessionAction::ForwardHtlcs { + parts: vec![p3.clone().into()], + channel_id: "chan-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: p3.clone(), + n_parts: 3, + parts_sum: Msat::from_msat(2_500), + }, + 
SessionEvent::ForwardHtlcs { + channel_id: "chan-1".to_owned(), + n_parts: 1, + parts_sum: Msat::from_msat(500), + opening_fee_msat: 0, + }, + ] + ); + } + + #[test] + fn allocate_forwards_allows_exact_deductible_capacity() { + let parts = vec![part(1, 1_000), part(2, 1_000)]; + + let forwards = allocate_forwards(&parts, 1_998).unwrap(); + + assert_eq!( + forwards, + vec![ + ForwardPart { + htlc_id: 1, + fee_msat: 999, + forward_msat: 1, + }, + ForwardPart { + htlc_id: 2, + fee_msat: 999, + forward_msat: 1, + }, + ] + ); + } + + #[test] + fn payment_settled_transitions_to_broadcasting_and_emits_broadcast_action() { + let mut s = session(4, Some(2_000), 1); + + let p1 = part(1, 1_000); + let p2 = part(2, 1_000); + let _ = s.apply(SessionInput::AddPart { part: p1.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p2.clone() }).unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::PaymentSettled).unwrap(); + + assert_eq!( + s.state, + SessionState::Broadcasting { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + } + ); + assert_eq!( + res.actions, + vec![SessionAction::BroadcastFundingTx { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![SessionEvent::PaymentSettled { + parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ] + }] + ); + } + + #[test] + fn channel_ready_with_too_many_parts_abandons_session_and_fails_htlcs() { + let mut s = session(2, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + // Extra part while awaiting channel ready. 
+ let _ = s + .apply(SessionInput::AddPart { part: part(3, 500) }) + .unwrap(); + + let res = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-overflow".to_owned(), + funding_psbt: "psbt-overflow".to_owned(), + }) + .unwrap(); + + assert_eq!(s.state, SessionState::Abandoned); + assert_eq!( + res.events, + vec![ + SessionEvent::ChannelReady { + channel_id: "chan-overflow".to_owned(), + funding_psbt: "psbt-overflow".to_owned(), + }, + SessionEvent::TooManyParts { n_parts: 3 }, + SessionEvent::SessionAbandoned, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::AbandonSession { + channel_id: "chan-overflow".to_owned(), + funding_psbt: "psbt-overflow".to_owned(), + }, + ] + ); + } + + #[test] + fn abandoned_rejects_further_inputs_with_invalid_transition() { + let mut s = session(2, Some(2_000), 1); + s.state = SessionState::Abandoned; + + let err = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap_err(); + + assert_eq!( + err, + Error::InvalidTransition { + state: SessionState::Abandoned, + input: SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }, + } + ); + } + + #[test] + fn broadcasting_add_part_forwards_single_htlc() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + let _ = s.apply(SessionInput::PaymentSettled).unwrap(); + + let p3 = part(3, 500); + let res = s.apply(SessionInput::AddPart { part: p3.clone() }).unwrap(); + + assert_eq!( + res.actions, + vec![SessionAction::ForwardHtlcs { + parts: 
vec![p3.clone().into()], + channel_id: "chan-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: p3.clone(), + n_parts: 1, + parts_sum: Msat::from_msat(500), + }, + SessionEvent::ForwardHtlcs { + channel_id: "chan-1".to_owned(), + n_parts: 1, + parts_sum: Msat::from_msat(500), + opening_fee_msat: 0, + }, + ] + ); + } + + #[test] + fn funding_broadcasted_transitions_to_succeeded() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + let _ = s.apply(SessionInput::PaymentSettled).unwrap(); + + let res = s.apply(SessionInput::FundingBroadcasted).unwrap(); + + assert_eq!(s.state, SessionState::Succeeded); + assert_eq!(res.actions, vec![]); + assert_eq!( + res.events, + vec![ + SessionEvent::FundingBroadcasted { + funding_psbt: "psbt-1".to_owned(), + }, + SessionEvent::SessionSucceeded, + ] + ); + } + + #[test] + fn succeeded_rejects_new_inputs_with_invalid_transition() { + let mut s = session(4, Some(2_000), 1); + s.state = SessionState::Succeeded; + + let err = s + .apply(SessionInput::AddPart { + part: part(99, 1_000), + }) + .unwrap_err(); + + assert_eq!( + err, + Error::InvalidTransition { + state: SessionState::Succeeded, + input: SessionInput::AddPart { + part: part(99, 1_000), + }, + } + ); + } + + #[test] + fn funding_failed_in_awaiting_channel_ready_fails_htlcs_and_transitions_to_failed() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingChannelReady { .. 
})); + + let res = s.apply(SessionInput::FundingFailed).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ] + ); + assert_eq!(res.events, vec![SessionEvent::SessionFailed]); + } + + #[test] + fn funding_failed_in_awaiting_channel_ready_with_extra_parts_reports_all() { + let mut s = session(5, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + // Extra part arrived while awaiting channel ready. + let _ = s + .apply(SessionInput::AddPart { part: part(3, 500) }) + .unwrap(); + + let res = s.apply(SessionInput::FundingFailed).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!(res.events, vec![SessionEvent::SessionFailed]); + } + + #[test] + fn collecting_unexpected_funding_failed_emits_unusual_input() { + let mut s = session(3, Some(2_000), 1); + + let res = s.apply(SessionInput::FundingFailed).unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + + #[test] + fn funding_failed_is_terminal() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s.apply(SessionInput::FundingFailed).unwrap(); + + assert!(s.is_terminal()); + + let err = s + .apply(SessionInput::AddPart { part: part(3, 500) }) + .unwrap_err(); + assert!(matches!(err, Error::InvalidTransition { .. 
})); + } + + #[test] + fn new_block_collecting_timeout_fails_session() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(1, 1_000, 50), + }) + .unwrap(); + + let res = s.apply(SessionInput::NewBlock { height: 51 }).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!( + res.events, + vec![ + SessionEvent::UnsafeHtlcTimeout { + height: 51, + cltv_min: 50, + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn new_block_collecting_safe_height_is_noop() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(1, 1_000, 50), + }) + .unwrap(); + + let res = s.apply(SessionInput::NewBlock { height: 49 }).unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert!(res.events.is_empty()); + } + + #[test] + fn new_block_collecting_no_parts_is_noop() { + let mut s = session(3, Some(2_000), 1); + + let res = s.apply(SessionInput::NewBlock { height: 100 }).unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert!(res.events.is_empty()); + } + + #[test] + fn new_block_awaiting_channel_ready_timeout_fails_with_disconnect() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(1, 1_000, 50), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(2, 1_000, 60), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingChannelReady { .. 
})); + + let res = s.apply(SessionInput::NewBlock { height: 51 }).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::UnsafeHtlcTimeout { + height: 51, + cltv_min: 50, + }, + SessionEvent::SessionFailed, + ] + ); + } + + #[test] + fn new_block_awaiting_settlement_emits_unusual_input() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::NewBlock { height: 200 }).unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingSettlement { .. })); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. 
})); + } + + #[test] + fn awaiting_settlement_payment_failed_disconnects() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::PaymentFailed).unwrap(); + + assert_eq!(s.state, SessionState::Abandoned); + assert_eq!( + res.actions, + vec![ + SessionAction::AbandonSession { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }, + SessionAction::Disconnect, + ] + ); + assert_eq!( + res.events, + vec![SessionEvent::PaymentFailed, SessionEvent::SessionAbandoned] + ); + } + + #[test] + fn awaiting_settlement_unusual_inputs_emit_unusual_input() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + for input in [ + SessionInput::CollectTimeout, + SessionInput::FundingFailed, + SessionInput::FundingBroadcasted, + SessionInput::NewBlock { height: 100 }, + ] { + let res = s.apply(input).unwrap(); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. 
})); + } + } + + #[test] + fn broadcasting_unusual_inputs_emit_unusual_input() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + let _ = s.apply(SessionInput::PaymentSettled).unwrap(); + + for input in [ + SessionInput::CollectTimeout, + SessionInput::FundingFailed, + SessionInput::PaymentFailed, + SessionInput::NewBlock { height: 100 }, + ] { + let res = s.apply(input).unwrap(); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + } + + #[test] + fn recover_without_preimage_enters_awaiting_settlement() { + let (session, actions) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + None, + opening_fee_params(1_000, 0), + ); + assert!(actions.is_empty()); + assert!(!session.is_terminal()); + } + + #[test] + fn recover_with_preimage_enters_broadcasting() { + let (session, actions) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + Some("preimage-1".to_string()), + opening_fee_params(1_000, 0), + ); + assert_eq!(actions.len(), 1); + assert!(matches!( + &actions[0], + SessionAction::BroadcastFundingTx { channel_id, funding_psbt } + if channel_id == "channel-id-1" && funding_psbt == "psbt-1" + )); + assert!(!session.is_terminal()); + } + + #[test] + fn recovered_awaiting_settlement_transitions_on_payment_settled() { + let (mut session, _) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + None, + opening_fee_params(1_000, 0), + ); + let result = session.apply(SessionInput::PaymentSettled).unwrap(); + assert!(matches!( + result.actions.as_slice(), + [SessionAction::BroadcastFundingTx { .. 
}] + )); + } + + #[test] + fn recovered_awaiting_settlement_transitions_on_payment_failed() { + let (mut session, _) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + None, + opening_fee_params(1_000, 0), + ); + let result = session.apply(SessionInput::PaymentFailed).unwrap(); + assert!(matches!( + result.actions.as_slice(), + [SessionAction::AbandonSession { .. }, SessionAction::Disconnect] + )); + assert!(session.is_terminal()); + } + + #[test] + fn recovered_broadcasting_transitions_on_funding_broadcasted() { + let (mut session, _) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + Some("preimage-1".to_string()), + opening_fee_params(1_000, 0), + ); + let result = session.apply(SessionInput::FundingBroadcasted).unwrap(); + let _ = result; + assert!(session.is_terminal()); + assert_eq!(session.outcome(), Some(SessionOutcome::Succeeded)); + } +} diff --git a/plugins/lsps-plugin/src/lib.rs b/plugins/lsps-plugin/src/lib.rs index e1f5e07f4303..72174fc05f04 100644 --- a/plugins/lsps-plugin/src/lib.rs +++ b/plugins/lsps-plugin/src/lib.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "cln")] pub mod cln_adapters; pub mod core; pub mod proto; diff --git a/plugins/lsps-plugin/src/proto/lsps0.rs b/plugins/lsps-plugin/src/proto/lsps0.rs index 2cb72812931f..20d6f2857268 100644 --- a/plugins/lsps-plugin/src/proto/lsps0.rs +++ b/plugins/lsps-plugin/src/proto/lsps0.rs @@ -1,6 +1,7 @@ use crate::proto::jsonrpc::{JsonRpcRequest, RpcError}; use core::fmt; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use std::iter::Sum; use thiserror::Error; const MSAT_PER_SAT: u64 = 1_000; @@ -100,6 +101,12 @@ impl core::fmt::Display for Msat { } } +impl Sum for Msat { + fn sum>(iter: I) -> Self { + Msat(iter.map(|x| x.0).sum()) + } +} + impl Serialize for Msat { fn serialize(&self, serializer: S) -> std::result::Result where @@ -190,9 +197,79 @@ impl core::fmt::Display for Ppm { } } -/// Represents a short channel id as defined in 
LSPS0.scid. Matches with the
-/// implementation in cln_rpc.
-pub type ShortChannelId = cln_rpc::primitives::ShortChannelId;
+/// Represents a short channel id as defined in LSPS0.scid.
+/// Format: `{block}x{txindex}x{outnum}` encoding a u64 as
+/// `(block << 40) | (txindex << 16) | outnum`.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct ShortChannelId(u64);
+
+impl ShortChannelId {
+    pub fn block(&self) -> u32 {
+        (self.0 >> 40) as u32 & 0xFFFFFF
+    }
+    pub fn txindex(&self) -> u32 {
+        (self.0 >> 16) as u32 & 0xFFFFFF
+    }
+    pub fn outnum(&self) -> u16 {
+        self.0 as u16 & 0xFFFF
+    }
+    pub fn to_u64(&self) -> u64 {
+        self.0
+    }
+}
+
+impl From<u64> for ShortChannelId {
+    fn from(v: u64) -> Self {
+        ShortChannelId(v)
+    }
+}
+
+impl core::fmt::Display for ShortChannelId {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "{}x{}x{}", self.block(), self.txindex(), self.outnum())
+    }
+}
+
+impl core::str::FromStr for ShortChannelId {
+    type Err = String;
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        let parts: Vec<&str> = s.split('x').collect();
+        if parts.len() != 3 {
+            return Err(format!("Malformed short_channel_id: expected 3 parts, got {}", parts.len()));
+        }
+        let block: u64 = parts[0].parse().map_err(|e| format!("bad block: {e}"))?;
+        let txindex: u64 = parts[1].parse().map_err(|e| format!("bad txindex: {e}"))?;
+        let outnum: u64 = parts[2].parse().map_err(|e| format!("bad outnum: {e}"))?;
+        Ok(ShortChannelId((block << 40) | (txindex << 16) | outnum))
+    }
+}
+
+impl serde::Serialize for ShortChannelId {
+    fn serialize<S: serde::Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
+        serializer.serialize_str(&self.to_string())
+    }
+}
+
+impl<'de> serde::Deserialize<'de> for ShortChannelId {
+    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
+        let s: String = serde::Deserialize::deserialize(deserializer)?;
+        s.parse().map_err(serde::de::Error::custom)
+    }
+}
+
+#[cfg(feature = "cln")]
+impl From<ShortChannelId> for 
cln_rpc::primitives::ShortChannelId {
+    fn from(scid: ShortChannelId) -> Self {
+        cln_rpc::primitives::ShortChannelId::from(scid.0)
+    }
+}
+
+#[cfg(feature = "cln")]
+impl From<cln_rpc::primitives::ShortChannelId> for ShortChannelId {
+    fn from(scid: cln_rpc::primitives::ShortChannelId) -> Self {
+        ShortChannelId(scid.to_u64())
+    }
+}
 
 /// Represents a datetime as defined in LSPS0.datetime. Uses ISO8601 in UTC
 /// timezone.
diff --git a/plugins/lsps-plugin/src/proto/lsps2.rs b/plugins/lsps-plugin/src/proto/lsps2.rs
index 82767a27ea53..a8c3637b19c6 100644
--- a/plugins/lsps-plugin/src/proto/lsps2.rs
+++ b/plugins/lsps-plugin/src/proto/lsps2.rs
@@ -113,7 +113,7 @@ impl core::fmt::Display for PromiseError {
 
 impl core::error::Error for PromiseError {}
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Eq)]
 #[serde(try_from = "String")]
 pub struct Promise(pub String);
 
@@ -161,7 +161,7 @@ impl core::fmt::Display for Promise {
 
 /// Represents a set of parameters for calculating the opening fee for a JIT
 /// channel.
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 #[serde(deny_unknown_fields)] // LSPS2 requires the client to fail if a field is unrecognized. 
pub struct OpeningFeeParams {
     pub min_fee_msat: Msat,
@@ -280,15 +280,14 @@ pub struct Lsps2PolicyGetInfoResponse {
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct Lsps2PolicyGetChannelCapacityRequest {
+pub struct Lsps2PolicyBuyRequest {
     pub opening_fee_params: OpeningFeeParams,
-    pub init_payment_size: Msat,
-    pub scid: ShortChannelId,
+    pub payment_size_msat: Option<Msat>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct Lsps2PolicyGetChannelCapacityResponse {
-    pub channel_capacity_msat: Option<Msat>,
+pub struct Lsps2PolicyBuyResponse {
+    pub channel_capacity_msat: Option<Msat>,
 }
 
 /// An internal representation of a policy of parameters for calculating the
@@ -338,10 +337,43 @@ impl PolicyOpeningFeeParams {
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub struct DatastoreEntry {
-    pub peer_id: cln_rpc::primitives::PublicKey,
+    pub peer_id: bitcoin::secp256k1::PublicKey,
     pub opening_fee_params: OpeningFeeParams,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub expected_payment_size: Option<Msat>,
+    pub channel_capacity_msat: Msat,
+    pub created_at: DateTime<Utc>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub channel_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub funding_psbt: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub funding_txid: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(default)]
+    pub preimage: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(default)]
+    pub forwards_updated_index: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(default)]
+    pub payment_hash: Option<String>,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub enum SessionOutcome {
+    Succeeded,
+    Abandoned,
+    Failed,
+    Timeout,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct FinalizedDatastoreEntry {
+    #[serde(flatten)]
+    pub entry: DatastoreEntry,
+    pub outcome: SessionOutcome,
+    pub 
finalized_at: DateTime, } /// Computes the opening fee in millisatoshis as described in LSPS2. diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index 2e9ae10c28ce..58a22c4637a1 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -1,21 +1,35 @@ use anyhow::bail; use bitcoin::hashes::Hash; +use chrono::Utc; use cln_lsps::{ cln_adapters::{ - hooks::service_custommsg_hook, rpc::ClnApiRpc, sender::ClnSender, state::ServiceState, + hooks::service_custommsg_hook, + rpc::{ + ClnActionExecutor, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, ClnRpcClient, + }, + sender::ClnSender, + state::ServiceState, types::HtlcAcceptedRequest, }, core::{ lsps2::{ - htlc::{Htlc, HtlcAcceptedHookHandler, HtlcDecision, Onion, RejectReason}, + actor::HtlcResponse, + event_sink::NoopEventSink, + manager::{PaymentHash, SessionConfig, SessionManager}, + provider::{DatastoreProvider, RecoveryProvider}, service::Lsps2ServiceHandler, + session::PaymentPart, }, server::LspsService, + tlv::{TlvStream, TLV_FORWARD_AMT}, + }, + proto::{ + lsps0::{Msat, ShortChannelId}, + lsps2::{failure_codes::UNKNOWN_NEXT_PEER, SessionOutcome}, }, - proto::lsps0::{Msat, LSPS0_MESSAGE_TYPE}, }; -use cln_plugin::{options, HookBuilder, HookFilter, Plugin}; -use log::{debug, error, trace}; +use cln_plugin::{options, Plugin}; +use log::{debug, error, trace, warn}; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -30,23 +44,53 @@ pub const OPTION_PROMISE_SECRET: options::StringConfigOption = "A 64-character hex string that is the secret for promises", ); +pub const OPTION_COLLECT_TIMEOUT: options::DefaultIntegerConfigOption = + options::ConfigOption::new_i64_with_default( + "dev-lsps2-collect-timeout", + 90, + "Timeout in seconds for collecting MPP parts (default: 90)", + ); + #[derive(Clone)] struct State { lsps_service: Arc, sender: ClnSender, lsps2_enabled: bool, + datastore: Arc, + recovery: Arc, + session_manager: Arc>, } impl 
State { - pub fn new(rpc_path: PathBuf, promise_secret: &[u8; 32]) -> Self { - let api = Arc::new(ClnApiRpc::new(rpc_path.clone())); + pub fn new(rpc_path: PathBuf, promise_secret: &[u8; 32], collect_timeout_secs: u64) -> Self { + let rpc = ClnRpcClient::new(rpc_path.clone()); let sender = ClnSender::new(rpc_path); - let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(api, promise_secret)); + let datastore = Arc::new(ClnDatastore::new(rpc.clone())); + let policy = Arc::new(ClnPolicyProvider::new(rpc.clone())); + let executor = Arc::new(ClnActionExecutor::new(rpc.clone())); + let recovery = Arc::new(ClnRecoveryProvider::new(rpc)); + let lsps2_handler = Arc::new(Lsps2ServiceHandler::new( + datastore.clone(), + policy, + promise_secret, + )); let lsps_service = Arc::new(LspsService::builder().with_protocol(lsps2_handler).build()); + let session_manager = Arc::new(SessionManager::new( + datastore.clone(), + executor, + SessionConfig { + collect_timeout_secs, + ..SessionConfig::default() + }, + Arc::new(NoopEventSink), + )); Self { lsps_service, sender, lsps2_enabled: true, + datastore, + recovery, + session_manager, } } } @@ -66,6 +110,7 @@ async fn main() -> Result<(), anyhow::Error> { if let Some(plugin) = cln_plugin::Builder::new(tokio::io::stdin(), tokio::io::stdout()) .option(OPTION_ENABLED) .option(OPTION_PROMISE_SECRET) + .option(OPTION_COLLECT_TIMEOUT) // FIXME: Temporarily disabled lsp feature to please test cases, this is // ok as the feature is optional per spec. 
// We need to ensure that `connectd` only starts after all plugins have @@ -78,11 +123,10 @@ async fn main() -> Result<(), anyhow::Error> { // cln_plugin::FeatureBitsKind::Init, // util::feature_bit_to_hex(LSP_FEATURE_BIT), // ) - .hook_from_builder( - HookBuilder::new("custommsg", service_custommsg_hook) - .filters(vec![HookFilter::Int(i64::from(LSPS0_MESSAGE_TYPE))]), - ) + .hook("custommsg", service_custommsg_hook) .hook("htlc_accepted", on_htlc_accepted) + .subscribe("forward_event", on_forward_event) + .subscribe("block_added", on_block_added) .configure() .await? { @@ -118,7 +162,15 @@ async fn main() -> Result<(), anyhow::Error> { } }; - let state = State::new(rpc_path, &secret); + let collect_timeout_secs = plugin.option(&OPTION_COLLECT_TIMEOUT)? as u64; + let state = State::new(rpc_path, &secret, collect_timeout_secs); + + // Recover in-flight sessions before processing replayed HTLCs + let recovery: Arc = state.recovery.clone(); + if let Err(e) = state.session_manager.recover(recovery).await { + warn!("session recovery failed: {e}"); + } + let plugin = plugin.start(state).await?; plugin.join().await } else { @@ -161,65 +213,164 @@ async fn handle_htlc_inner( let req: HtlcAcceptedRequest = serde_json::from_value(v)?; - let short_channel_id = match req.onion.short_channel_id { - Some(scid) => scid, + let short_channel_id: ShortChannelId = match req.onion.short_channel_id { + Some(scid) => scid.into(), None => { trace!("We are the destination of the HTLC, continue."); return Ok(json_continue()); } }; - let rpc_path = Path::new(&p.configuration().lightning_dir).join(&p.configuration().rpc_file); - let api = ClnApiRpc::new(rpc_path); - // Fixme: Use real htlc_minimum_amount. - let handler = HtlcAcceptedHookHandler::new(api, 1000); - - let onion = Onion { - short_channel_id, - payload: req.onion.payload, + // Decide path: look up buy request to check for MPP. 
+ let ds_rec = match p.state().datastore.get_buy_request(&short_channel_id).await { + Ok(rec) => rec, + Err(_) => { + trace!("SCID not ours, continue."); + return Ok(json_continue()); + } }; - let htlc = Htlc { + if Utc::now() >= ds_rec.opening_fee_params.valid_until { + let _ = p + .state() + .datastore + .finalize_session(&short_channel_id, SessionOutcome::Timeout) + .await; + return Ok(json_fail(UNKNOWN_NEXT_PEER)); + } + + handle_session_htlc(p, &req, short_channel_id).await +} + +async fn handle_session_htlc( + p: &Plugin, + req: &HtlcAcceptedRequest, + scid: ShortChannelId, +) -> Result { + let payment_hash = PaymentHash::from_byte_array(req.htlc.payment_hash.as_slice().try_into()?); + let part = PaymentPart { + htlc_id: req.htlc.id, amount_msat: Msat::from_msat(req.htlc.amount_msat.msat()), - extra_tlvs: req.htlc.extra_tlvs.unwrap_or_default(), + cltv_expiry: req.htlc.cltv_expiry, }; + match p + .state() + .session_manager + .on_part(payment_hash, scid, part) + .await + { + Ok(resp) => session_response_to_json( + resp, + &req.onion.payload, + req.htlc.amount_msat.msat(), + &req.htlc.extra_tlvs, + ), + Err(e) => { + debug!("session manager error: {e:#}"); + Ok(json_continue()) + } + } +} + +fn session_response_to_json( + resp: HtlcResponse, + payload: &TlvStream, + _htlc_amount_msat: u64, + extra_tlvs: &Option, +) -> Result { + match resp { + HtlcResponse::Forward { + channel_id, + fee_msat, + forward_msat, + } => { + let mut payload = payload.clone(); + payload.set_tu64(TLV_FORWARD_AMT, forward_msat); - debug!("Handle potential jit-session HTLC."); - let response = match handler.handle(&htlc, &onion).await { - Ok(dec) => { - log_decision(&dec); - decision_to_response(dec)? 
+ let mut extra_tlvs = extra_tlvs.clone().unwrap_or_default(); + extra_tlvs.set_u64(65537, fee_msat); + + let forward_to = hex::decode(&channel_id)?; + + Ok(json_continue_forward( + payload.to_bytes()?, + forward_to, + extra_tlvs.to_bytes()?, + )) } - Err(e) => { - // Fixme: Should we log **BROKEN** here? - debug!("Htlc handler failed (continuing): {:#}", e); - return Ok(json_continue()); + HtlcResponse::Fail { failure_code } => Ok(json_fail(failure_code)), + HtlcResponse::Continue => Ok(json_continue()), + } +} + +async fn on_forward_event(p: Plugin, v: serde_json::Value) -> Result<(), anyhow::Error> { + let event = match v.get("forward_event") { + Some(e) => e, + None => return Ok(()), + }; + + let status = event.get("status").and_then(|s| s.as_str()); + + let payment_hash = match status { + Some("settled") | Some("failed") | Some("local_failed") => { + let hash_hex = match event.get("payment_hash").and_then(|s| s.as_str()) { + Some(h) => h, + None => return Ok(()), + }; + let bytes: [u8; 32] = hex::decode(hash_hex)? + .try_into() + .map_err(|v: Vec| anyhow::anyhow!("bad payment_hash len {}", v.len()))?; + PaymentHash::from_byte_array(bytes) } + _ => return Ok(()), }; - Ok(serde_json::to_value(&response)?) 
+ let updated_index = event.get("updated_index").and_then(|v| v.as_u64()); + + match status { + Some("settled") => { + let preimage = event + .get("preimage") + .and_then(|s| s.as_str()) + .map(|s| s.to_string()); + + if let Err(e) = p + .state() + .session_manager + .on_payment_settled(payment_hash, preimage, updated_index) + .await + { + debug!("on_payment_settled error: {e:#}"); + } + } + Some("failed") | Some("local_failed") => { + if let Err(e) = p + .state() + .session_manager + .on_payment_failed(payment_hash, updated_index) + .await + { + debug!("on_payment_failed error: {e:#}"); + } + } + _ => unreachable!(), + } + + Ok(()) } -fn decision_to_response(decision: HtlcDecision) -> Result { - Ok(match decision { - HtlcDecision::NotOurs => json_continue(), - - HtlcDecision::Forward { - mut payload, - forward_to, - mut extra_tlvs, - } => json_continue_forward( - payload.to_bytes()?, - forward_to.as_byte_array().to_vec(), - extra_tlvs.to_bytes()?, - ), +async fn on_block_added(p: Plugin, v: serde_json::Value) -> Result<(), anyhow::Error> { + let height = match v + .get("block_added") + .and_then(|b| b.get("height")) + .and_then(|h| h.as_u64()) + { + Some(h) => h as u32, + None => return Ok(()), + }; - // Fixme: once we implement MPP-Support we need to remove this. - HtlcDecision::Reject { - reason: RejectReason::MppNotSupported, - } => json_continue(), - HtlcDecision::Reject { reason } => json_fail(reason.failure_code()), - }) + p.state().session_manager.on_new_block(height).await; + Ok(()) } fn json_continue() -> serde_json::Value { @@ -245,20 +396,3 @@ fn json_fail(failure_code: &str) -> serde_json::Value { "failure_message": failure_code }) } - -fn log_decision(decision: &HtlcDecision) { - match decision { - HtlcDecision::NotOurs => { - trace!("SCID not ours, continue"); - } - HtlcDecision::Forward { forward_to, .. 
} => { - debug!( - "Forwarding via JIT channel {}", - hex::encode(forward_to.as_byte_array()) - ); - } - HtlcDecision::Reject { reason } => { - debug!("Rejecting HTLC: {:?}", reason); - } - } -} diff --git a/tests/plugins/lsps2_policy.py b/tests/plugins/lsps2_policy.py index d71fc67035d9..a88dbbf24874 100755 --- a/tests/plugins/lsps2_policy.py +++ b/tests/plugins/lsps2_policy.py @@ -42,10 +42,10 @@ def lsps2_policy_getpolicy(request): } -@plugin.method("lsps2-policy-getchannelcapacity") -def lsps2_policy_getchannelcapacity(request, init_payment_size, scid, opening_fee_params): - """Returns an opening fee menu for the LSPS2 plugin.""" - return {"channel_capacity_msat": 100000000} +@plugin.method("lsps2-policy-buy") +def lsps2_policy_buy(request, opening_fee_params, payment_size_msat=None): + """Returns the channel capacity for a buy request.""" + return {"channel_capacity_msat": "100000000"} plugin.run() diff --git a/tests/plugins/lsps2_service_mock.py b/tests/plugins/lsps2_service_mock.py deleted file mode 100755 index fecd1a58baa2..000000000000 --- a/tests/plugins/lsps2_service_mock.py +++ /dev/null @@ -1,205 +0,0 @@ -#!/usr/bin/env python3 -""" -Zero‑conf LSPS2 mock -==================== - -• On the **first incoming HTLC**, call `connect` and `fundchannel` with **zeroconf** to a configured peer. -• **Hold all HTLCs** until the channel reports `CHANNELD_NORMAL`, then **continue** them all. -• After the channel is ready, future HTLCs are continued immediately. 
-""" - -import threading -import time -import struct -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone -from typing import Dict, Optional -from pyln.client import Plugin -from pyln.proto.onion import TlvPayload - - -plugin = Plugin() - - -@plugin.method("lsps2-policy-getpolicy") -def lsps2_policy_getpolicy(request): - """Returns an opening fee menu for the LSPS2 plugin.""" - now = datetime.now(timezone.utc) - - # Is ISO 8601 format "YYYY-MM-DDThh:mm:ss.uuuZ" - valid_until = (now + timedelta(hours=1)).isoformat().replace("+00:00", "Z") - - return { - "policy_opening_fee_params_menu": [ - { - "min_fee_msat": "1000000", - "proportional": 0, - "valid_until": valid_until, - "min_lifetime": 2000, - "max_client_to_self_delay": 2016, - "min_payment_size_msat": "1000", - "max_payment_size_msat": "100000000", - }, - ] - } - - -@plugin.method("lsps2-policy-getchannelcapacity") -def lsps2_policy_getchannelcapacity( - request, init_payment_size, scid, opening_fee_params -): - """Returns an opening fee menu for the LSPS2 plugin.""" - return {"channel_capacity_msat": 100000000} - - -TLV_OPENING_FEE = 65537 - - -@dataclass -class Held: - htlc: dict - onion: dict - event: threading.Event = field(default_factory=threading.Event) - response: Optional[dict] = None - - -@dataclass -class State: - target_peer: Optional[str] = None - channel_cap: Optional[int] = None - opening_fee_msat: Optional[int] = None - pending: Dict[str, Held] = field(default_factory=dict) - funding_started: bool = False - channel_ready: bool = False - channel_id_hex: Optional[str] = None - fee_remaining_msat: int = 0 - worker_thread: Optional[threading.Thread] = None - lock: threading.Lock = field(default_factory=threading.Lock) - - -state = State() - - -def _key(h: dict) -> str: - return f"{h.get('id', '?')}:{h.get('payment_hash', '?')}" - - -def _ensure_zero_conf_channel(peer_id: str, capacity: int) -> bool: - plugin.log(f"fundchannel zero-conf to {peer_id} for 
{capacity} sat...") - res = plugin.rpc.fundchannel( - peer_id, - capacity, - announce=False, - mindepth=0, - channel_type=[12, 46, 50], - ) - plugin.log(f"got channel response {res}") - state.channel_id_hex = res["channel_id"] - - for _ in range(120): - channels = plugin.rpc.listpeerchannels(peer_id)["channels"] - for c in channels: - if c.get("state") == "CHANNELD_NORMAL": - plugin.log("zero-conf channel is NORMAL; releaseing HTLCs") - return True - time.sleep(1) - return False - - -def _modify_payload_and_build_response(held: Held): - amt_msat = int(held.htlc.get("amount_msat", 0)) - fee_applied = 0 - if state.fee_remaining_msat > 0: - fee_applied = min(state.fee_remaining_msat, max(amt_msat - 1, 0)) - state.fee_remaining_msat -= fee_applied - forward_msat = max(1, amt_msat - fee_applied) - - payload = None - extra = None - if amt_msat != forward_msat: - amt_byte = struct.pack("!Q", forward_msat) - while len(amt_byte) > 1 and amt_byte[0] == 0: - amt_byte = amt_byte[1:] - payload = TlvPayload().from_hex(held.onion["payload"]) - p = TlvPayload() - p.add_field(2, amt_byte) - p.add_field(4, payload.get(4).value) - p.add_field(6, payload.get(6).value) - payload = p.to_bytes(include_prefix=False) - - amt_byte = fee_applied.to_bytes(8, "big") - e = TlvPayload() - e.add_field(TLV_OPENING_FEE, amt_byte) - extra = e.to_bytes(include_prefix=False) - - resp = {"result": "continue"} - if payload: - resp["payload"] = payload.hex() - if extra: - resp["extra_tlvs"] = extra.hex() - if state.channel_id_hex: - resp["forward_to"] = state.channel_id_hex - return resp - - -def _release_all_locked(): - # called with state.lock held - items = list(state.pending.items()) - state.pending.clear() - for _k, held in items: - if held.response is None: - held.response = _modify_payload_and_build_response(held) - held.event.set() - - -def _worker(): - plugin.log("collecting htlcs and fund channel...") - with state.lock: - peer = state.target_peer - cap = state.channel_cap - fee = 
state.opening_fee_msat - if not peer or not cap or not fee: - with state.lock: - _release_all_locked() - return - - ok = _ensure_zero_conf_channel(peer, cap) - with state.lock: - state.channel_ready = ok - state.fee_remaining_msat = fee if ok else 0 - _release_all_locked() - - -@plugin.method("setuplsps2service") -def setuplsps2service(plugin, peer_id, channel_cap, opening_fee_msat): - state.target_peer = peer_id - state.channel_cap = channel_cap - state.opening_fee_msat = opening_fee_msat - - -@plugin.async_hook("htlc_accepted") -def on_htlc_accepted(htlc, onion, request, plugin, **kwargs): - key = _key(htlc) - - with state.lock: - if state.channel_ready: - held_now = Held(htlc=htlc, onion=onion) - resp = _modify_payload_and_build_response(held_now) - request.set_result(resp) - return - - if not state.funding_started: - state.funding_started = True - state.worker_thread = threading.Thread(target=_worker, daemon=True) - state.worker_thread.start() - - # enqueue and block until the worker releases us - held = Held(htlc=htlc, onion=onion) - state.pending[key] = held - - held.event.wait() - request.set_result(held.response) - - -if __name__ == "__main__": - plugin.run() diff --git a/tests/test_cln_lsps.py b/tests/test_cln_lsps.py index 7eb2d4ffd197..f87294c1f520 100644 --- a/tests/test_cln_lsps.py +++ b/tests/test_cln_lsps.py @@ -1,11 +1,100 @@ -from fixtures import * # noqa: F401,F403 -from pyln.testing.utils import RUST -from utils import only_one +import json import os -import pytest +import time import unittest +import pytest +from fixtures import * # noqa: F401,F403 +from pyln.testing.utils import RUST, wait_for +from utils import only_one + RUST_PROFILE = os.environ.get("RUST_PROFILE", "debug") +POLICY_PLUGIN = os.path.join(os.path.dirname(__file__), "plugins/lsps2_policy.py") +LSP_OPTS = { + "experimental-lsps2-service": None, + "experimental-lsps2-promise-secret": "0" * 64, + "dev-lsps2-collect-timeout": 5, + "plugin": POLICY_PLUGIN, + "fee-base": 0, + 
"fee-per-satoshi": 0, +} + + +def setup_lsps2_network( + node_factory, bitcoind, lsp_opts=None, client_opts=None, may_reconnect=False +): + """Create l1 (client), l2 (LSP), l3 (payer) with l3--l2 funded. + + Returns (l1, l2, l3, chanid) where chanid is the l3-l2 channel. + """ + opts = lsp_opts or LSP_OPTS + client = client_opts or {} + l1_opts = {"experimental-lsps-client": None, **client} + if may_reconnect: + l1_opts["may_reconnect"] = True + opts = {**opts, "may_reconnect": True} + l1, l2, l3 = node_factory.get_nodes( + 3, + opts=[ + l1_opts, + opts, + {"may_reconnect": True} if may_reconnect else {}, + ], + ) + + l2.fundwallet(1_000_000) + node_factory.join_nodes([l3, l2], fundchannel=True, wait_for_announce=True) + node_factory.join_nodes([l1, l2], fundchannel=False) + + chanid = only_one(l3.rpc.listpeerchannels(l2.info["id"])["channels"])[ + "short_channel_id" + ] + return l1, l2, l3, chanid + + +def buy_and_invoice(l1, l2, amt): + """Buy a JIT channel and create a fixed-amount invoice. + + Returns (dec, inv) where dec is the decoded invoice dict. 
+ """ + inv = l1.rpc.lsps_lsps2_invoice( + lsp_id=l2.info["id"], + amount_msat=f"{amt}msat", + description="lsp-jit-channel", + label=f"lsp-jit-channel-{time.monotonic_ns()}", + ) + dec = l2.rpc.decode(inv["bolt11"]) + return dec, inv + + +def send_mpp(l3, l2_id, l1_id, chanid, dec, inv, amt, parts): + """Send an MPP payment split into equal parts via sendpay.""" + routehint = only_one(only_one(dec["routes"])) + route_part = [ + { + "amount_msat": amt // parts, + "id": l2_id, + "delay": routehint["cltv_expiry_delta"] + 6, + "channel": chanid, + }, + { + "amount_msat": amt // parts, + "id": l1_id, + "delay": 6, + "channel": routehint["short_channel_id"], + }, + ] + + for partid in range(1, parts + 1): + l3.rpc.sendpay( + route_part, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=partid, + ) def test_lsps_service_disabled(node_factory): @@ -193,25 +282,12 @@ def test_lsps2_buyjitchannel_no_mpp_var_invoice(node_factory, bitcoind): assert l1.rpc.listdatastore(["lsps"]) == {"datastore": []} -def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): - """Tests the creation of a "Just-In-Time-Channel" (jit-channel). - - At the beginning we have the following situation where l2 acts as the LSP - (LSP) - l1 l2----l3 - - l1 now wants to get a channel from l2 via the lsps2 jit-channel protocol: - - l1 requests a new jit channel form l2 - - l1 creates an invoice based on the opening fee parameters it got from l2 - - l3 pays the invoice - - l2 opens a channel to l1 and forwards the payment (deducted by a fee) - - eventualy this will result in the following situation - (LSP) - l1----l2----l3 +def test_lsps2_non_approved_zero_conf(node_factory, bitcoind): + """Checks that we don't allow zerof_conf channels from an LSP if we did + not approve it first. """ - # A mock for lsps2 mpp payments, contains the policy plugin as well. 
- plugin = os.path.join(os.path.dirname(__file__), "plugins/lsps2_service_mock.py") + # We need a policy service to fetch from. + plugin = os.path.join(os.path.dirname(__file__), "plugins/lsps2_policy.py") l1, l2, l3 = node_factory.get_nodes( 3, @@ -224,7 +300,7 @@ def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): "fee-base": 0, # We are going to deduct our fee anyways, "fee-per-satoshi": 0, # We are going to deduct our fee anyways, }, - {}, + {"disable-mpp": None}, ], ) @@ -234,26 +310,489 @@ def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): node_factory.join_nodes([l3, l2], fundchannel=True, wait_for_announce=True) node_factory.join_nodes([l1, l2], fundchannel=False) + fee_opt = l1.rpc.lsps_lsps2_getinfo(lsp_id=l2.info["id"])[ + "opening_fee_params_menu" + ][0] + buy_res = l1.rpc.lsps_lsps2_buy(lsp_id=l2.info["id"], opening_fee_params=fee_opt) + + hint = [ + [ + { + "id": l2.info["id"], + "short_channel_id": buy_res["jit_channel_scid"], + "fee_base_msat": 0, + "fee_proportional_millionths": 0, + "cltv_expiry_delta": buy_res["lsp_cltv_expiry_delta"], + } + ] + ] + + bolt11 = l1.dev_invoice( + amount_msat="any", + description="lsp-invoice-1", + label="lsp-invoice-1", + dev_routes=hint, + )["bolt11"] + + with pytest.raises(ValueError): + l3.rpc.pay(bolt11, amount_msat=10000000) + + # l1 shouldn't have a new channel. + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 0 + + +def test_lsps2_session_mpp_happy_path(node_factory, bitcoind): + """Full MPP happy path through the real session FSM. + + FSM path: Collecting → AwaitingChannelReady → AwaitingSettlement + → Broadcasting → Succeeded + + Exercises SessionSucceeded and FundingBroadcasted events. 
+ """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 5 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] + + # l1 should have exactly one JIT channel. + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + + # Funding tx should eventually be broadcast (session reached Succeeded). + # Mine a block so the funding confirms. + bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) + ) + + # Datastore should be cleaned up on the client side. + assert l1.rpc.listdatastore(["lsps"]) == {"datastore": []} + + +def test_lsps2_session_mpp_two_parts(node_factory, bitcoind): + """MPP with exactly 2 parts — minimal split. + + Verifies that the session FSM correctly collects and forwards with + small part counts. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] + + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + assert l1.rpc.listdatastore(["lsps"]) == {"datastore": []} + + +def test_lsps2_session_mpp_single_part(node_factory, bitcoind): + """Fixed-amount invoice paid with a single part. + + Even though the payment is a single HTLC, the session path is used + because expected_payment_size is set. Tests the degenerate MPP case. 
+ """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 1 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] + + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + + +def test_lsps2_session_mpp_collection_timeout(node_factory, bitcoind): + """Partial MPP that never reaches the threshold times out. + + FSM path: Collecting → (timeout) → Failed + + Exercises SessionFailed event. The HTLCs should be failed back with + TEMPORARY_CHANNEL_FAILURE. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + + # Invoice for 10M msat but we'll only send 1 part of 1M. + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + routehint = only_one(only_one(dec["routes"])) + + # Send 1 part out of what should be many — not enough to reach threshold. + route = [ + { + "amount_msat": amt // 10, + "id": l2.info["id"], + "delay": routehint["cltv_expiry_delta"] + 6, + "channel": chanid, + }, + { + "amount_msat": amt // 10, + "id": l1.info["id"], + "delay": 6, + "channel": routehint["short_channel_id"], + }, + ] + + l3.rpc.sendpay( + route, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=1, + ) + + # The session FSM collect timeout (5s in tests). Wait for it to fire. + with pytest.raises(Exception) as exc_info: + l3.rpc.waitsendpay(dec["payment_hash"], partid=1, groupid=1, timeout=30) + # The HTLC should be failed back. + assert ( + "WIRE_TEMPORARY_CHANNEL_FAILURE" in str(exc_info.value) + or exc_info.value is not None + ) + + # No JIT channel should have been created. 
+ chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 0 + + +def test_lsps2_session_mpp_fundchannel_fails_no_funds(node_factory, bitcoind): + """LSP has no funds to open a channel — fundchannel_start fails. + + FSM path: Collecting → AwaitingChannelReady → FundingFailed → Failed + + All held HTLCs should be failed back. + """ + # Override: do NOT fund the LSP's wallet. + l1, l2, l3 = node_factory.get_nodes( + 3, + opts=[ + {"experimental-lsps-client": None}, + LSP_OPTS, + {}, + ], + ) + + # Fund l3-l2 channel but do NOT fund l2's wallet beyond what join_nodes gives. + node_factory.join_nodes([l3, l2], fundchannel=True, wait_for_announce=True) + node_factory.join_nodes([l1, l2], fundchannel=False) + chanid = only_one(l3.rpc.listpeerchannels(l2.info["id"])["channels"])[ "short_channel_id" ] amt = 10_000_000 - inv = l1.rpc.lsps_lsps2_invoice( - lsp_id=l2.info["id"], - amount_msat=f"{amt}msat", - description="lsp-jit-channel-0", - label="lsp-jit-channel-0", + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # The FSM should try fund_channel, fail (no funds), and fail HTLCs. + with pytest.raises(Exception): + l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) + + # No JIT channel should have been created. + chs = l1.rpc.listpeerchannels(l2.info["id"])["channels"] + assert len(chs) == 0 + + +def test_lsps2_session_mpp_peer_disconnects_before_payment(node_factory, bitcoind): + """Client (l1) disconnects from LSP before payment arrives. + + The fund_channel action should fail because the peer is unreachable. + + FSM path: Collecting → AwaitingChannelReady → FundingFailed → Failed + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + # Disconnect l1 from l2 before sending payment. 
+ l1.rpc.disconnect(l2.info["id"], force=True) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # fund_channel should fail: peer disconnected. + with pytest.raises(Exception): + l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) + + # No JIT channel. + chs = l1.rpc.listpeerchannels(l2.info["id"])["channels"] + assert len(chs) == 0 + + +def test_lsps2_session_datastore_has_funding_fields(node_factory, bitcoind): + """Verify the LSP's finalized datastore entry contains funding fields. + + After a successful JIT channel session, the LSP (l2) should persist a + finalized entry with channel_id, funding_psbt, and funding_txid populated. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 5 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] + + # Mine a block so the funding confirms and session reaches Succeeded. + bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) ) - dec = l3.rpc.decode(inv["bolt11"]) - l2.rpc.setuplsps2service( - peer_id=l1.info["id"], channel_cap=100_000, opening_fee_msat=1000_000 + # Wait for the finalized entry to appear on the LSP's datastore. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + + # Read and parse the finalized entry. 
+    ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])
+    entry_raw = only_one(ds["datastore"])
+    entry = json.loads(entry_raw["string"])
+
+    assert entry["outcome"] == "Succeeded"
+    assert isinstance(entry["channel_id"], str) and entry["channel_id"]
+    assert isinstance(entry["funding_psbt"], str) and entry["funding_psbt"]
+    assert isinstance(entry["funding_txid"], str) and entry["funding_txid"]
+    assert isinstance(entry["preimage"], str) and len(entry["preimage"]) == 64
+
+    # Active entries should have been cleaned up.
+    active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])
+    assert active["datastore"] == []
+
+
+def test_lsps2_session_payment_failed_abandoned(node_factory, bitcoind):
+    """MPP payment fails after HTLCs are forwarded — session ends as Abandoned.
+
+    FSM path: Collecting → AwaitingChannelReady → AwaitingSettlement → Abandoned
+
+    Uses 4 MPP parts so multiple forward_event "failed" notifications hit the
+    session manager, exercising idempotent cleanup of the dead actor handle.
+    """
+    l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind)
+    amt = 10_000_000
+    dec, inv = buy_and_invoice(l1, l2, amt)
+
+    # Delete the invoice on l1 so it can't settle the payment.
+    # The JIT channel will still be accepted (gated by datastore, not invoice).
+    invoices = l1.rpc.listinvoices()["invoices"]
+    for i in invoices:
+        if i["status"] == "unpaid":
+            l1.rpc.delinvoice(i["label"], "unpaid")
+
+    parts = 4
+    send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts)
+
+    # l1 rejects all parts (no invoice) → forward_event "failed" on l2 → Abandoned.
+    for partid in range(1, parts + 1):
+        with pytest.raises(Exception):
+            l3.rpc.waitsendpay(
+                dec["payment_hash"], partid=partid, groupid=1, timeout=60
+            )
+
+    # Wait for the finalized entry on l2's datastore.
+ wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Abandoned" + + # AbandonSession calls close(unilateraltimeout=1) + unreserveinputs, + # so l2 should have dropped/be closing the channel. + wait_for(lambda: len(l2.rpc.listpeerchannels(l1.info["id"])["channels"]) == 0) + + # unreserveinputs should have freed all UTXOs on the LSP. + assert not any(o["reserved"] for o in l2.rpc.listfunds()["outputs"]) + + +def test_lsps2_session_newblock_unsafe_htlc_timeout(node_factory, bitcoind): + """Partial MPP with low CLTV delay times out when blocks are mined. + + FSM path: Collecting → NewBlock{height > cltv_min} → Failed + + Sends one partial part with a small CLTV delay so that mining a few + blocks triggers UnsafeHtlcTimeout before the 5s collect timeout fires. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) routehint = only_one(only_one(dec["routes"])) - parts = 10 + # Use small delay so cltv_expiry is close to current height. + # The htlc_accepted hook intercepts before CLN's CLTV validation, + # so the small delta is accepted by the LSPS2 plugin. + route = [ + { + "amount_msat": amt // 10, + "id": l2.info["id"], + "delay": 10, + "channel": chanid, + }, + { + "amount_msat": amt // 10, + "id": l1.info["id"], + "delay": 6, + "channel": routehint["short_channel_id"], + }, + ] + + # Send one partial part — not enough to reach threshold, stays in Collecting. + l3.rpc.sendpay( + route, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=1, + ) + + # Mine blocks past cltv_expiry (current_height + 10). + # height becomes current_height + 11 > current_height + 10. 
+ bitcoind.generate_block(11) + + # The HTLC should be failed back by the FSM. + with pytest.raises(Exception): + l3.rpc.waitsendpay(dec["payment_hash"], partid=1, groupid=1, timeout=30) + + # No JIT channel should have been created. + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 0 + + # Wait for finalized datastore entry with Failed outcome. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Failed" + + +def test_lsps2_session_cltv_force_close_abandoned(node_factory, bitcoind): + """CLTV deadline force-close triggers Abandoned via channel poll. + + FSM path: Collecting → AwaitingChannelReady → AwaitingSettlement → Abandoned + + l1 holds HTLCs via hold_htlcs. Blocks are mined until l2's outgoing HTLC + CLTV deadline is hit. CLN force-closes the channel. The per-session + listpeerchannels poll detects the channel is no longer CHANNELD_NORMAL + and sends ChannelClosed, transitioning the session to Abandoned. + """ + hold_plugin = os.path.join(os.path.dirname(__file__), "plugins/hold_htlcs.py") + l1, l2, l3, chanid = setup_lsps2_network( + node_factory, + bitcoind, + client_opts={"plugin": hold_plugin, "hold-time": 10000}, + ) + + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # Wait for l1 to hold HTLCs (session in AwaitingSettlement). + l1.daemon.wait_for_log("Holding onto an incoming htlc for 10000 seconds") + + # Mine blocks past CLTV deadline → l2 force-closes JIT channel. + bitcoind.generate_block(8) + l2.daemon.wait_for_log( + r"Peer permanent failure in CHANNELD_NORMAL.*cltv.*hit deadline" + ) + + # Verify: channel poll detects closed channel, FSM reaches Abandoned. 
+ wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Abandoned" + + # Active session should be cleaned up. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + # Channel should be completely gone on l2. + wait_for(lambda: len(l2.rpc.listpeerchannels(l1.info["id"])["channels"]) == 0) + + # UTXOs should be unreserved and spendable. + assert not any(o["reserved"] for o in l2.rpc.listfunds()["outputs"]) + + # l2 force-closed → HTLCs failed upstream → l3's payment should fail. + for partid in range(1, parts + 1): + with pytest.raises(Exception): + l3.rpc.waitsendpay( + dec["payment_hash"], partid=partid, groupid=1, timeout=60 + ) + + +def test_lsps2_restart_collecting_htlcs_replayed(node_factory, bitcoind): + """Restart during collecting phase — replayed HTLCs create fresh session. + + Recovery path: pre-funding session in datastore → restart → CLN replays + unhandled HTLCs → new session collects and completes successfully. 
+ """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind, may_reconnect=True) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 5 + routehint = only_one(only_one(dec["routes"])) route_part = [ { "amount_msat": amt // parts, @@ -269,9 +808,10 @@ def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): }, ] - # MPP-payment of fixed amount - for partid in range(1, parts + 1): - r = l3.rpc.sendpay( + for partid in range( + 1, parts + ): # One part is missing to make sure we are actually in CollectingParts state + l3.rpc.sendpay( route_part, dec["payment_hash"], payment_secret=inv["payment_secret"], @@ -280,74 +820,287 @@ def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): groupid=1, partid=partid, ) - assert r - res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + # Restart l2 after 4 of 5 parts arrived — plugin has not funded a channel yet + l2.daemon.wait_for_log(r"PaymentPartAdded.*n_parts: 4") + l2.restart() + l2.connect(l3) + l2.connect(l1) + wait_for( + lambda: ( + only_one(l2.rpc.listpeerchannels(l3.info["id"])["channels"]).get("state") + == "CHANNELD_NORMAL" + ) + ) + + # CLN replays all unhandled HTLCs after restart. The recovery + replay + # should result in a successful payment regardless of how far the + # original session got. Still need to send the last part + l3.rpc.sendpay( + route_part, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=parts, + ) + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) assert res["payment_preimage"] - # l1 should have gotten a jit-channel. + # l1 should have exactly one JIT channel. chs = l1.rpc.listpeerchannels()["channels"] assert len(chs) == 1 - # Check that the client cleaned up after themselves. - assert l1.rpc.listdatastore("lsps") == {"datastore": []} + # Mine a block so the funding confirms. 
+ bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) + ) + # Finalized entry should show success. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Succeeded" -def test_lsps2_non_approved_zero_conf(node_factory, bitcoind): - """Checks that we don't allow zerof_conf channels from an LSP if we did - not approve it first. + # Active entries should be empty. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + +def test_lsps2_restart_pre_funding_expired_finalized_timeout(node_factory, bitcoind): + """Restart with expired pre-funding session — finalized as Timeout. + + Recovery path: session valid_until has passed, no channel funded → + recovery classifies as Timeout. """ - # We need a policy service to fetch from. - plugin = os.path.join(os.path.dirname(__file__), "plugins/lsps2_policy.py") + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + # Tamper with the active session's valid_until so recovery sees it as + # expired. This avoids needing a short-validity policy plugin which + # conflicts with the client's 1-minute safety margin. 
+ active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + ds_entry = only_one(active["datastore"]) + session = json.loads(ds_entry["string"]) + session["opening_fee_params"]["valid_until"] = "2000-01-01T00:00:00.000Z" + l2.rpc.datastore( + key=ds_entry["key"], + string=json.dumps(session), + mode="must-replace", + ) - l1, l2, l3 = node_factory.get_nodes( - 3, - opts=[ - {"experimental-lsps-client": None}, - { - "experimental-lsps2-service": None, - "experimental-lsps2-promise-secret": "0" * 64, - "plugin": plugin, - "fee-base": 0, # We are going to deduct our fee anyways, - "fee-per-satoshi": 0, # We are going to deduct our fee anyways, - }, - {"disable-mpp": None}, - ], + # Restart l2 — recovery finds expired session with no channel. + l2.restart() + + # Recovery should finalize the session as Timeout. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Timeout" - # Give the LSP some funds to open jit-channels - l2.fundwallet(1_000_000) + # Active entries should be cleaned up. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] - node_factory.join_nodes([l3, l2], fundchannel=True, wait_for_announce=True) - node_factory.join_nodes([l1, l2], fundchannel=False) + # No channel should exist between l1 and l2. 
+ chs = l1.rpc.listpeerchannels(l2.info["id"])["channels"] + assert len(chs) == 0 - fee_opt = l1.rpc.lsps_lsps2_getinfo(lsp_id=l2.info["id"])[ - "opening_fee_params_menu" - ][0] - buy_res = l1.rpc.lsps_lsps2_buy(lsp_id=l2.info["id"], opening_fee_params=fee_opt) - hint = [ - [ - { - "id": l2.info["id"], - "short_channel_id": buy_res["jit_channel_scid"], - "fee_base_msat": 0, - "fee_proportional_millionths": 0, - "cltv_expiry_delta": buy_res["lsp_cltv_expiry_delta"], - } - ] - ] +def test_lsps2_restart_awaiting_settlement_payment_completes(node_factory, bitcoind): + """Restart while HTLCs are held — recovered session settles successfully. - bolt11 = l1.dev_invoice( - amount_msat="any", - description="lsp-invoice-1", - label="lsp-invoice-1", - dev_routes=hint, - )["bolt11"] + Recovery path: funded session with OFFERED forwards → recover as + AwaitingSettlement → forward monitoring → payment settles → Succeeded. + """ + hold_plugin = os.path.join(os.path.dirname(__file__), "plugins/hold_htlcs.py") + l1, l2, l3, chanid = setup_lsps2_network( + node_factory, + bitcoind, + client_opts={"plugin": hold_plugin, "hold-time": 15}, + may_reconnect=True, + ) + # JIT channels can trigger bookkeeper "Unable to calculate fees" on restart. + l2.broken_log = r"Unable to calculate fees collected" - with pytest.raises(ValueError): - l3.rpc.pay(bolt11, amount_msat=10000000) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # Wait for l1 to hold HTLCs (channel funded, HTLCs forwarded). + l1.daemon.wait_for_log("Holding onto an incoming htlc for 15 seconds") + + # Confirm early persistence: active session has channel_id. 
+ wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + ) + > 0 + and json.loads( + only_one( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + )["string"] + ).get("channel_id") + is not None + ) + ) - # l1 shouldn't have a new channel. + # Restart l2 while HTLCs are held on l1. + l2.restart() + l2.connect(l3) + l2.connect(l1) + + # Hold expires → l1 settles → recovered actor detects SETTLED → Succeeded. + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) + assert res["payment_preimage"] + + # l1 should have exactly one JIT channel. chs = l1.rpc.listpeerchannels()["channels"] - assert len(chs) == 0 + assert len(chs) == 1 + + # Mine a block so the funding confirms. + bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) + ) + + # Finalized entry should show success with funding_txid. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Succeeded" + assert isinstance(entry["funding_txid"], str) and entry["funding_txid"] + + # Active entries should be empty. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + +def test_lsps2_restart_awaiting_settlement_payment_fails_abandoned( + node_factory, bitcoind +): + """Restart while HTLCs are held, payment fails — session Abandoned. + + Recovery path: funded session with OFFERED forwards → recover as + AwaitingSettlement → forward monitoring → forwards fail → Abandoned. 
+ """ + hold_plugin = os.path.join(os.path.dirname(__file__), "plugins/hold_htlcs.py") + l1, l2, l3, chanid = setup_lsps2_network( + node_factory, + bitcoind, + client_opts={"plugin": hold_plugin, "hold-time": 15}, + may_reconnect=True, + ) + # JIT channels can trigger bookkeeper "Unable to calculate fees" on restart. + l2.broken_log = r"Unable to calculate fees collected" + + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + # Delete the invoice on l1 so it will reject HTLCs after hold expires. + invoices = l1.rpc.listinvoices()["invoices"] + for i in invoices: + if i["status"] == "unpaid": + l1.rpc.delinvoice(i["label"], "unpaid") + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # Wait for l1 to hold HTLCs (channel funded, HTLCs forwarded). + l1.daemon.wait_for_log("Holding onto an incoming htlc for 15 seconds") + + # Confirm early persistence: active session has channel_id. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + ) + > 0 + and json.loads( + only_one( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + )["string"] + ).get("channel_id") + is not None + ) + ) + + # Restart l2 while HTLCs are held. + l2.restart() + l2.connect(l3) + l2.connect(l1) + + # Hold expires → l1 rejects (no invoice) → forwards fail → Abandoned. + for partid in range(1, parts + 1): + with pytest.raises(Exception): + l3.rpc.waitsendpay( + dec["payment_hash"], partid=partid, groupid=1, timeout=60 + ) + + # Finalized entry should show Abandoned. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Abandoned" + + # Channel should be gone on l2. 
+ wait_for(lambda: len(l2.rpc.listpeerchannels(l1.info["id"])["channels"]) == 0) + + # UTXOs should be unreserved. + assert not any(o["reserved"] for o in l2.rpc.listfunds()["outputs"])