diff --git a/Cargo.toml b/Cargo.toml index dda2d03..d13295e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,3 +48,8 @@ openssl = ["dep:openssl"] use-rustls = ["rustls"] use-rustls-ring = ["rustls-ring"] use-openssl = ["openssl"] + +# optional dependencies only used in `jwt_dynamic_auth` example +[dev-dependencies] +tokio = { version = "1", features = ["full"] } +bitreq = { version = "0.3.4", features = ["async-https", "json-using-serde"] } diff --git a/examples/jwt_auth.rs b/examples/jwt_auth.rs new file mode 100644 index 0000000..9045fd8 --- /dev/null +++ b/examples/jwt_auth.rs @@ -0,0 +1,88 @@ +//! # JWT Static Authentication with Electrum Client +//! +//! This example demonstrates how to use a static JWT_TOKEN authentication with the +//! electrum-client library. + +use bitcoin::Txid; +use electrum_client::{Client, ConfigBuilder, ElectrumApi}; +use std::{str::FromStr, sync::Arc}; + +const ELECTRUM_URL: &str = "ssl://electrum.blockstream.info:50002"; + +const GENESIS_HEIGHT: usize = 0; +const GENESIS_TXID: &str = "4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b"; + +fn main() { + // A static JWT_TOKEN (i.e JWT_TOKEN="Bearer jwt_token...") + let auth_provider = Arc::new(move || { + let jwt_token = std::env::var("JWT_TOKEN").expect("JWT_TOKEN env variable not set"); + Some(jwt_token) + }); + + // The Electrum Server URL (i.e `ELECTRUM_URL` environment variable, or defaults to `ELECTRUM_URL` const above) + let electrum_url = std::env::var("ELECTRUM_URL").unwrap_or(ELECTRUM_URL.to_owned()); + + // Builds the electrum-client `Config`. + let config = ConfigBuilder::new() + .validate_domain(false) + .authorization_provider(Some(auth_provider)) + .build(); + + // Builds & Connect electrum-client `Client`. + match Client::from_config(&electrum_url, config) { + Ok(client) => { + println!( + "Successfully connected to Electrum Server: {:#?}; with JWT authentication!", + electrum_url + ); + + // try to call the `server.features` method, it can fail on some servers. + match client.server_features() { + Ok(features) => println!( + "Successfully fetched the `server.features`!\n{:#?}", + features + ), + Err(e) => eprintln!("Failed to fetch the `server.features`!\nError: {:#?}", e), + } + + // try to call the `blockchain.block.header` method, it should NOT fail. + let genesis_height = GENESIS_HEIGHT; + match client.block_header(genesis_height) { + Ok(header) => { + println!( + "Successfully fetched the `Header` for given `height`={}!\n{:#?}", + genesis_height, header + ); + } + Err(err) => eprintln!( + "Failed to fetch the `Header` for given `height`!\nError: {:#?}", + err + ), + } + + // try to call the `blockchain.transaction.get` method, it should NOT fail. + let genesis_txid = + Txid::from_str(GENESIS_TXID).expect("SHOULD have a valid genesis `txid`"); + match client.transaction_get(&genesis_txid) { + Ok(tx) => { + println!( + "Successfully fetched the `Transaction` for given `txid`={}!\n{:#?}", + genesis_txid, tx + ); + } + Err(err) => eprintln!( + "Failed to fetch the `Transaction` for given `txid`!\nError: {:#?}", + err + ), + } + } + Err(err) => { + eprintln!( + "Failed to build and connect `Client` to {:#?}!\nError: {:#?}\n", + electrum_url, err + ); + eprintln!("NOTE: This example requires an Electrum Server that handles/accept JWT authentication!"); + eprintln!("Try to update the `ELECTRUM_URL` and `JWT_TOKEN to match your setup."); + } + } +} diff --git a/examples/jwt_dynamic_auth.rs b/examples/jwt_dynamic_auth.rs new file mode 100644 index 0000000..1fed314 --- /dev/null +++ b/examples/jwt_dynamic_auth.rs @@ -0,0 +1,183 @@ +//! # JWT Dynamic Authentication +//! +//! ## Advanced: Token Refresh with Keycloak +//! +//! This example demonstrates how to use dynamic JWT authentication with the +//! electrum-client library. +//! +//! ## Overview +//! +//! The electrum-client supports embedding authorization tokens (such as JWT +//! Bearer tokens) directly in JSON-RPC requests. This is achieved through an +//! [`AuthProvider`](electrum_client::config::AuthProvider) callback that is +//! invoked before each request. +//! +//! In order to have an automatic token refresh (e.g it expires every 5 minutes), +//! you should use a shared token holder (e.g KeycloakTokenManager) +//! behind an `Arc>` and spawn a background task to refresh it. +//! +//! ## JSON-RPC Request Format +//! +//! With the auth provider configured, each JSON-RPC request will include the +//! authorization field: +//! +//! ```json +//! { +//! "jsonrpc": "2.0", +//! "method": "blockchain.headers.subscribe", +//! "params": [], +//! "id": 1, +//! "authorization": "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9..." +//! } +//! ``` +//! +//! If the provider returns `None`, the authorization field is omitted from the +//! request. +//! +//! ## Thread Safety +//! +//! The `AuthProvider` type is defined as: +//! +//! ```rust,ignore +//! pub type AuthProvider = Arc Option + Send + Sync>; +//! ``` +//! +//! This ensures thread-safe access to tokens across all RPC calls. + +use electrum_client::{Client, ConfigBuilder, ElectrumApi}; +use std::sync::{Arc, RwLock}; +use std::time::Duration; +use tokio::time::sleep; + +/// Manages JWT tokens from Keycloak with automatic refresh +struct KeycloakTokenManager { + token: Arc>>, + keycloak_url: String, + grant_type: String, + client_id: String, + client_secret: String, +} + +impl KeycloakTokenManager { + fn new( + keycloak_url: String, + grant_type: String, + client_id: String, + client_secret: String, + ) -> Self { + Self { + token: Arc::new(RwLock::new(None)), + keycloak_url, + client_id, + client_secret, + grant_type, + } + } + + /// Get the current token (for the auth provider) + fn get_token(&self) -> Option { + self.token.read().unwrap().clone() + } + + /// Fetch a fresh token from Keycloak + async fn fetch_token(&self) -> Result> { + let url = format!("{}/protocol/openid-connect/token", self.keycloak_url); + + // if you're using other HTTP client (i.e `reqwest`), you can probably use `.form` methods. + // it's currently not implemented in `bitreq`, needs to be built manually. + let body = format!( + "grant_type={}&client_id={}&client_secret={}", + self.grant_type, self.client_id, self.client_secret + ); + + let response = bitreq::post(url) + .with_header("Content-Type", "application/x-www-form-urlencoded") + .with_body(body) + .send_async() + .await?; + + let json: serde_json::Value = response.json()?; + let access_token = json["access_token"] + .as_str() + .ok_or("Missing access_token")? + .to_string(); + + Ok(format!("Bearer {}", access_token)) + } + + /// Background task that refreshes the token every 4 minutes + async fn refresh_loop(self: Arc) { + loop { + // Refresh every 4 minutes (tokens expire at 5 minutes) + sleep(Duration::from_secs(240)).await; + + match self.fetch_token().await { + Ok(new_token) => { + println!("Token refreshed successfully"); + // In a background thread/task, periodically update the token + *self.token.write().unwrap() = Some(new_token); + } + Err(e) => { + eprintln!("Failed to refresh token: {}", e); + // Keep using old token until we can refresh + } + } + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // The Electrum Server URL (i.e `ELECTRUM_URL` environment variable) + let electrum_url = std::env::var("ELECTRUM_URL") + .expect("SHOULD have the `ELECTRUM_URL` environment variable!"); + + // The JWT_TOKEN manager setup (i.e Keycloak server URL, client ID and secret) + let keycloak_url = std::env::var("KEYCLOAK_URL") + .expect("SHOULD have the `KEYCLOAK_URL` environment variable!"); + + let grant_type = std::env::var("GRANT_TYPE").unwrap_or("client_credentials".to_string()); + let client_id = + std::env::var("CLIENT_ID").expect("SHOULD have the `CLIENT_ID` environment variable!"); + let client_secret = std::env::var("CLIENT_SECRET") + .expect("SHOULD have the `CLIENT_SECRET` environment variable!"); + + // Setup `KeycloakTokenManager` + let token_manager = Arc::new(KeycloakTokenManager::new( + keycloak_url, + grant_type, + client_id, + client_secret, + )); + + // Fetch initial token + let jwt_token = token_manager.fetch_token().await?; + + println!("JWT_TOKEN='{}'", &jwt_token[..jwt_token.len().min(40)]); + + *token_manager.token.write().unwrap() = Some(jwt_token); + + // Start background refresh task + let tm_clone = token_manager.clone(); + tokio::spawn(async move { + tm_clone.refresh_loop().await; + }); + + // Create Electrum client with dynamic auth provider + let tm_for_provider = token_manager.clone(); + let config = ConfigBuilder::new() + .authorization_provider(Some(Arc::new(move || tm_for_provider.get_token()))) + .build(); + + let client = Client::from_config(&electrum_url, config)?; + + // All RPC calls will automatically include fresh JWT tokens + loop { + match client.server_features() { + Ok(features) => println!("Connected: {:?}", features), + Err(e) => eprintln!("Error: {}", e), + } + + tokio::time::sleep(Duration::from_secs(10)).await; + } +} diff --git a/src/client.rs b/src/client.rs index cb59a4d..4f9e457 100644 --- a/src/client.rs +++ b/src/client.rs @@ -112,26 +112,36 @@ impl ClientType { /// Constructor that supports multiple backends and allows configuration through /// the [Config] pub fn from_config(url: &str, config: &Config) -> Result { + let auth_provider = config.authorization_provider().cloned(); + #[cfg(any(feature = "openssl", feature = "rustls", feature = "rustls-ring"))] if url.starts_with("ssl://") { let url = url.replacen("ssl://", "", 1); #[cfg(feature = "proxy")] - let client = match config.socks5() { + let raw_client = match config.socks5() { Some(socks5) => RawClient::new_proxy_ssl( url.as_str(), config.validate_domain(), socks5, config.timeout(), + auth_provider, + )?, + None => RawClient::new_ssl( + url.as_str(), + config.validate_domain(), + config.timeout(), + auth_provider, )?, - None => { - RawClient::new_ssl(url.as_str(), config.validate_domain(), config.timeout())? - } }; #[cfg(not(feature = "proxy"))] - let client = - RawClient::new_ssl(url.as_str(), config.validate_domain(), config.timeout())?; - - return Ok(ClientType::SSL(client)); + let raw_client = RawClient::new_ssl( + url.as_str(), + config.validate_domain(), + config.timeout(), + auth_provider, + )?; + + return Ok(ClientType::SSL(raw_client)); } #[cfg(not(any(feature = "openssl", feature = "rustls", feature = "rustls-ring")))] @@ -143,18 +153,28 @@ impl ClientType { { let url = url.replacen("tcp://", "", 1); + #[cfg(feature = "proxy")] let client = match config.socks5() { Some(socks5) => ClientType::Socks5(RawClient::new_proxy( url.as_str(), socks5, config.timeout(), + auth_provider, + )?), + None => ClientType::TCP(RawClient::new( + url.as_str(), + config.timeout(), + auth_provider, )?), - None => ClientType::TCP(RawClient::new(url.as_str(), config.timeout())?), }; #[cfg(not(feature = "proxy"))] - let client = ClientType::TCP(RawClient::new(url.as_str(), config.timeout())?); + let client = ClientType::TCP(RawClient::new( + url.as_str(), + config.timeout(), + auth_provider, + )?); Ok(client) } diff --git a/src/config.rs b/src/config.rs index e4c5770..b3e6d7a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,12 +1,16 @@ +use std::sync::Arc; use std::time::Duration; +/// A function that provides authorization tokens dynamically (e.g., for JWT refresh) +pub type AuthProvider = Arc Option + Send + Sync>; + /// Configuration for an electrum client /// /// Refer to [`Client::from_config`] and [`ClientType::from_config`]. /// /// [`Client::from_config`]: crate::Client::from_config /// [`ClientType::from_config`]: crate::ClientType::from_config -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Config { /// Proxy socks5 configuration, default None socks5: Option, @@ -16,6 +20,24 @@ pub struct Config { retry: u8, /// when ssl, validate the domain, default true validate_domain: bool, + /// Optional authorization provider for dynamic token injection + authorization_provider: Option, +} + +// Custom Debug impl because AuthProvider doesn't implement Debug +impl std::fmt::Debug for Config { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Config") + .field("socks5", &self.socks5) + .field("timeout", &self.timeout) + .field("retry", &self.retry) + .field("validate_domain", &self.validate_domain) + .field( + "authorization_provider", + &self.authorization_provider.as_ref().map(|_| ""), + ) + .finish() + } } /// Configuration for Socks5 @@ -72,6 +94,12 @@ impl ConfigBuilder { self } + /// Sets the authorization provider for dynamic token injection + pub fn authorization_provider(mut self, provider: Option) -> Self { + self.config.authorization_provider = provider; + self + } + /// Return the config and consume the builder pub fn build(self) -> Config { self.config @@ -131,6 +159,13 @@ impl Config { self.validate_domain } + /// Get the configuration for `authorization_provider` + /// + /// Set this with [`ConfigBuilder::authorization_provider`] + pub fn authorization_provider(&self) -> Option<&AuthProvider> { + self.authorization_provider.as_ref() + } + /// Convenience method for calling [`ConfigBuilder::new`] pub fn builder() -> ConfigBuilder { ConfigBuilder::new() @@ -144,6 +179,106 @@ impl Default for Config { timeout: None, retry: 1, validate_domain: true, + authorization_provider: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_authorization_provider_builder() { + let token = "test-token-123".to_string(); + let provider = Arc::new(move || Some(format!("Bearer {}", token))); + + let config = ConfigBuilder::new() + .authorization_provider(Some(provider.clone())) + .build(); + + assert!(config.authorization_provider().is_some()); + + // Test that the provider returns the expected value + if let Some(auth_provider) = config.authorization_provider() { + assert_eq!(auth_provider(), Some("Bearer test-token-123".to_string())); } } + + #[test] + fn test_authorization_provider_none() { + let config = ConfigBuilder::new().build(); + + assert!(config.authorization_provider().is_none()); + } + + #[test] + fn test_authorization_provider_returns_none() { + let provider = Arc::new(|| None); + + let config = ConfigBuilder::new() + .authorization_provider(Some(provider)) + .build(); + + assert!(config.authorization_provider().is_some()); + + // Test that the provider returns None + if let Some(auth_provider) = config.authorization_provider() { + assert_eq!(auth_provider(), None); + } + } + + #[test] + fn test_authorization_provider_dynamic_token() { + use std::sync::RwLock; + + // Simulate a token that can be updated + let token = Arc::new(RwLock::new("initial-token".to_string())); + let token_clone = token.clone(); + + let provider = Arc::new(move || Some(token_clone.read().unwrap().clone())); + + let config = ConfigBuilder::new() + .authorization_provider(Some(provider.clone())) + .build(); + + // Initial token + if let Some(auth_provider) = config.authorization_provider() { + assert_eq!(auth_provider(), Some("initial-token".to_string())); + } + + // Update the token + *token.write().unwrap() = "refreshed-token".to_string(); + + // Provider should return the new token + if let Some(auth_provider) = config.authorization_provider() { + assert_eq!(auth_provider(), Some("refreshed-token".to_string())); + } + } + + #[test] + fn test_config_debug_with_provider() { + let provider = Arc::new(|| Some("secret-token".to_string())); + + let config = ConfigBuilder::new() + .authorization_provider(Some(provider)) + .build(); + + let debug_str = format!("{:?}", config); + + // Should show instead of the actual function pointer + assert!(debug_str.contains("")); + // Should not leak the token value + assert!(!debug_str.contains("secret-token")); + } + + #[test] + fn test_config_debug_without_provider() { + let config = ConfigBuilder::new().build(); + + let debug_str = format!("{:?}", config); + + // Should show None for authorization_provider + assert!(debug_str.contains("authorization_provider")); + } } diff --git a/src/lib.rs b/src/lib.rs index d916b12..94ff0aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,5 +58,5 @@ pub mod utils; pub use api::ElectrumApi; pub use batch::Batch; pub use client::*; -pub use config::{Config, ConfigBuilder, Socks5Config}; +pub use config::{AuthProvider, Config, ConfigBuilder, Socks5Config}; pub use types::*; diff --git a/src/raw_client.rs b/src/raw_client.rs index 68f5a63..daa9fe6 100644 --- a/src/raw_client.rs +++ b/src/raw_client.rs @@ -37,6 +37,7 @@ use crate::stream::ClonableStream; use crate::api::ElectrumApi; use crate::batch::Batch; +use crate::config::AuthProvider; use crate::types::*; /// Client name sent to the server during protocol version negotiation. @@ -132,17 +133,16 @@ impl_to_socket_addrs_domain!((std::net::Ipv6Addr, u16)); /// Instance of an Electrum client /// -/// A `Client` maintains a constant connection with an Electrum server and exposes methods to -/// interact with it. It can also subscribe and receive notifictations from the server about new +/// A [`RawClient`] maintains a constant connection with an Electrum server and exposes methods to +/// interact with it. It can also subscribe and receive notifications from the server about new /// blocks or activity on a specific *scriptPubKey*. /// -/// The `Client` is modeled in such a way that allows the external caller to have full control over +/// The [`RawClient`] is modeled in such a way that allows the external caller to have full control over /// its functionality: no threads or tasks are spawned internally to monitor the state of the /// connection. /// /// More transport methods can be used by manually creating an instance of this struct with an -/// arbitray `S` type. -#[derive(Debug)] +/// arbitrary `S` type. pub struct RawClient where S: Read + Write, @@ -159,9 +159,33 @@ where /// The protocol version negotiated with the server via `server.version`. protocol_version: Mutex>, + /// Optional authorization provider for dynamic token injection (e.g., JWT). + auth_provider: Option, + calls: AtomicUsize, } +// Custom Debug impl because AuthProvider doesn't implement Debug +impl std::fmt::Debug for RawClient +where + S: Read + Write, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RawClient") + .field("stream", &"") + .field("buf_reader", &"") + .field("last_id", &self.last_id) + .field("waiting_map", &self.waiting_map) + .field("headers", &self.headers) + .field("script_notifications", &self.script_notifications) + .field( + "auth_provider", + &self.auth_provider.as_ref().map(|_| ""), + ) + .finish() + } +} + impl From for RawClient where S: Read + Write, @@ -181,6 +205,8 @@ where protocol_version: Mutex::new(None), + auth_provider: None, + calls: AtomicUsize::new(0), } } @@ -196,6 +222,7 @@ impl RawClient { pub fn new( socket_addrs: A, timeout: Option, + auth_provider: Option, ) -> Result { let stream = match timeout { Some(timeout) => { @@ -207,8 +234,10 @@ impl RawClient { None => TcpStream::connect(socket_addrs)?, }; - let client: Self = stream.into(); - client.negotiate_protocol_version()?; + let client = Self::from(stream) + .with_auth(auth_provider) + .negotiate_protocol_version()?; + Ok(client) } } @@ -261,6 +290,7 @@ impl RawClient { socket_addrs: A, validate_domain: bool, timeout: Option, + auth_provider: Option, ) -> Result { debug!( "new_ssl socket_addrs.domain():{:?} validate_domain:{} timeout:{:?}", @@ -276,11 +306,11 @@ impl RawClient { let stream = connect_with_total_timeout(socket_addrs.clone(), timeout)?; stream.set_read_timeout(Some(timeout))?; stream.set_write_timeout(Some(timeout))?; - Self::new_ssl_from_stream(socket_addrs, validate_domain, stream) + Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider) } None => { let stream = TcpStream::connect(socket_addrs.clone())?; - Self::new_ssl_from_stream(socket_addrs, validate_domain, stream) + Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider) } } } @@ -290,9 +320,11 @@ impl RawClient { socket_addrs: A, validate_domain: bool, stream: TcpStream, + auth_provider: Option, ) -> Result { let mut builder = SslConnector::builder(SslMethod::tls()).map_err(Error::InvalidSslMethod)?; + // TODO: support for certificate pinning if validate_domain { socket_addrs.domain().ok_or(Error::MissingDomain)?; @@ -307,8 +339,10 @@ impl RawClient { .connect(&domain, stream) .map_err(Error::SslHandshakeError)?; - let client: Self = stream.into(); - client.negotiate_protocol_version()?; + let client = Self::from(stream) + .with_auth(auth_provider) + .negotiate_protocol_version()?; + Ok(client) } } @@ -384,6 +418,7 @@ impl RawClient { socket_addrs: A, validate_domain: bool, timeout: Option, + auth_provider: Option, ) -> Result { debug!( "new_ssl socket_addrs.domain():{:?} validate_domain:{} timeout:{:?}", @@ -391,19 +426,21 @@ impl RawClient { validate_domain, timeout ); + if validate_domain { socket_addrs.domain().ok_or(Error::MissingDomain)?; } + match timeout { Some(timeout) => { let stream = connect_with_total_timeout(socket_addrs.clone(), timeout)?; stream.set_read_timeout(Some(timeout))?; stream.set_write_timeout(Some(timeout))?; - Self::new_ssl_from_stream(socket_addrs, validate_domain, stream) + Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider) } None => { let stream = TcpStream::connect(socket_addrs.clone())?; - Self::new_ssl_from_stream(socket_addrs, validate_domain, stream) + Self::new_ssl_from_stream(socket_addrs, validate_domain, stream, auth_provider) } } } @@ -413,6 +450,7 @@ impl RawClient { socket_addr: A, validate_domain: bool, tcp_stream: TcpStream, + auth_provider: Option, ) -> Result { use std::convert::TryFrom; @@ -476,8 +514,10 @@ impl RawClient { .map_err(Error::CouldNotCreateConnection)?; let stream = StreamOwned::new(session, tcp_stream); - let client: Self = stream.into(); - client.negotiate_protocol_version()?; + let client = Self::from(stream) + .with_auth(auth_provider) + .negotiate_protocol_version()?; + Ok(client) } } @@ -494,6 +534,7 @@ impl RawClient { target_addr: T, proxy: &crate::Socks5Config, timeout: Option, + auth_provider: Option, ) -> Result { let mut stream = match proxy.credentials.as_ref() { Some(cred) => Socks5Stream::connect_with_password( @@ -508,8 +549,10 @@ impl RawClient { stream.get_mut().set_read_timeout(timeout)?; stream.get_mut().set_write_timeout(timeout)?; - let client: Self = stream.into(); - client.negotiate_protocol_version()?; + let client = Self::from(stream) + .with_auth(auth_provider) + .negotiate_protocol_version()?; + Ok(client) } @@ -525,6 +568,7 @@ impl RawClient { validate_domain: bool, proxy: &crate::Socks5Config, timeout: Option, + auth_provider: Option, ) -> Result, Error> { let target = target_addr.to_target_addr()?; @@ -538,10 +582,11 @@ impl RawClient { )?, None => Socks5Stream::connect(&proxy.addr, target.clone(), timeout)?, }; + stream.get_mut().set_read_timeout(timeout)?; stream.get_mut().set_write_timeout(timeout)?; - RawClient::new_ssl_from_stream(target, validate_domain, stream.into_inner()) + RawClient::new_ssl_from_stream(target, validate_domain, stream.into_inner(), auth_provider) } } @@ -560,15 +605,39 @@ impl RawClient { // `ClonableStream` before other threads can send a request to the server. They will block // waiting for the reader to release the mutex, but this will never happen because the server // didn't receive any request, so it has nothing to send back. + // // pub fn reader_thread(&self) -> Result<(), Error> { // self._reader_thread(None).map(|_| ()) // } - /// Negotiates the protocol version with the server. + /// Sets the [`AuthProvider`] for this client, enabling authentication on all + /// outgoing RPC requests. + /// + /// The `auth_provider` is a callback invoked before each request, allowing + /// dynamic token strategies such as automatic JWT refresh without + /// reconnecting the client. Passing `None` or not calling this method + /// disables authentication. + /// + /// # Notes + /// + /// This method should be called **before** [`RawClient::negotiate_protocol_version`], + /// as the initial `server.version` handshake also requires authentication + /// on protected servers. + fn with_auth(mut self, auth_provider: Option) -> Self { + self.auth_provider = auth_provider; + self + } + + /// Negotiates the Electrum protocol version with the Electrum server. /// /// This sends `server.version` as the first message and stores the negotiated - /// protocol version. Called automatically by constructors. - fn negotiate_protocol_version(&self) -> Result<(), Error> { + /// protocol version. + /// + /// As of Electrum Protocol v1.6 it's a mandatory step, see: + /// + /// + /// [`ClientType`]: crate::ClientType + fn negotiate_protocol_version(self) -> Result { let version_range = vec![ PROTOCOL_VERSION_MIN.to_string(), PROTOCOL_VERSION_MAX.to_string(), @@ -585,7 +654,7 @@ impl RawClient { let response: ServerVersionRes = serde_json::from_value(result)?; *self.protocol_version.lock()? = Some(response.protocol_version); - Ok(()) + Ok(self) } fn _reader_thread(&self, until_message: Option) -> Result { @@ -715,6 +784,14 @@ impl RawClient { let (sender, receiver) = channel(); self.waiting_map.lock()?.insert(req.id, sender); + // apply `authorization` token into `Request`, if any. + let authorization = self + .auth_provider + .as_ref() + .and_then(|auth_provider| auth_provider()); + + let req = req.with_auth(authorization); + let mut raw = serde_json::to_vec(&req)?; trace!("==> {}", String::from_utf8_lossy(&raw)); @@ -832,12 +909,31 @@ impl ElectrumApi for RawClient { // Add our listener to the map before we send the request - for (method, params) in batch.iter() { - let req = Request::new_id( + for (idx, (method, params)) in batch.iter().enumerate() { + let mut req = Request::new_id( self.last_id.fetch_add(1, Ordering::SeqCst), method, params.to_vec(), ); + + // Although the library DOES NOT use JSON-RPC batch arrays, + // It applies the `authorization` ONLY in the first `Request` of the `Batch`. + // + // JWT tokens can be 1KB+, therefore duplicating it across multiple requests adds significant overhead. + // It assumes the server authenticates the `Batch` by the first `Request`. If a server implementation treats + // each newline-delimited request independently, subsequently `Request`'s would be unauthenticated. + // + // It's a known trade-off, not a bug. + if idx == 0 { + // it should get the `authorization`, if there's an `auth_provider` available. + let authorization = self + .auth_provider + .as_ref() + .and_then(|auth_provider| auth_provider()); + + req = req.with_auth(authorization); + } + // Add distinct channel to each request so when we remove our request id (and sender) from the waiting_map // we can be sure that the response gets sent to the correct channel in self.recv let (sender, receiver) = channel(); @@ -1315,11 +1411,30 @@ mod test { use super::{ElectrumSslStream, RawClient}; use crate::api::ElectrumApi; + use crate::config::AuthProvider; + + // it's the default live testing electrum server, if you'd like to use a custom one set it up through + // the environment variable `TEST_ELECTRUM_SERVER`. + // + // here's an useful list of live servers: https://1209k.com/bitcoin-eye/ele.php. + const DEFAULT_TEST_ELECTRUM_SERVER: &str = "fortress.qtornado.com:443"; + + fn get_test_auth_client( + authorization_provider: Option, + ) -> RawClient { + let server = std::env::var("TEST_ELECTRUM_SERVER") + .unwrap_or(DEFAULT_TEST_ELECTRUM_SERVER.to_owned()); + + RawClient::new_ssl(&*server, false, None, authorization_provider) + .expect("should build the `RawClient` successfully!") + } fn get_test_client() -> RawClient { - let server = - std::env::var("TEST_ELECTRUM_SERVER").unwrap_or("fortress.qtornado.com:443".into()); - RawClient::new_ssl(&*server, false, None).unwrap() + let server = std::env::var("TEST_ELECTRUM_SERVER") + .unwrap_or(DEFAULT_TEST_ELECTRUM_SERVER.to_owned()); + + RawClient::new_ssl(&*server, false, None, None) + .expect("should build the `RawClient` successfully!") } #[test] @@ -1800,4 +1915,78 @@ mod test { 00000" ) } + + #[test] + fn test_authorization_provider_with_client() { + use std::sync::{Arc, RwLock}; + + // Track how many times the provider is called + let call_count = Arc::new(RwLock::new(0)); + let call_count_clone = call_count.clone(); + + let auth_provider = Arc::new(move || { + *call_count_clone.write().unwrap() += 1; + Some("Bearer test-token-123".to_string()) + }); + + let client = get_test_auth_client(Some(auth_provider)); + + // Make a request - provider should be called + let _ = client.server_features(); + + // Provider should have been called at least once + assert!(*call_count.read().unwrap() >= 1); + } + + #[test] + fn test_authorization_provider_dynamic_token_refresh() { + use std::sync::{Arc, RwLock}; + + // Simulate a token that can be refreshed + let token = Arc::new(RwLock::new("initial-token".to_string())); + let token_clone = token.clone(); + + let auth_provider: AuthProvider = + Arc::new(move || Some(token_clone.read().unwrap().clone())); + + let client = get_test_auth_client(Some(auth_provider.clone())); + + // Make first request with initial token + let _ = client.server_features(); + + // Simulate token refresh + *token.write().unwrap() = "refreshed-token".to_string(); + + // Make second request - should use the new token + let _ = client.server_features(); + + // Verify the provider now returns the refreshed token + assert_eq!(auth_provider(), Some("refreshed-token".to_string())); + } + + #[test] + fn test_authorization_provider_returns_none() { + use std::sync::Arc; + + let auth_provider: AuthProvider = Arc::new(|| None); + + let client = get_test_auth_client(Some(auth_provider)); + + // Should still work when provider returns None + let result = client.server_features(); + assert!(result.is_ok()); + } + + #[test] + fn test_auth_provider_via_constructor() { + use std::sync::Arc; + + let auth_provider: AuthProvider = Arc::new(|| Some("Bearer test".to_string())); + + let client = get_test_auth_client(Some(auth_provider)); + + // Verify the provider was set + let result = client.server_features(); + assert!(result.is_ok()); + } } diff --git a/src/types.rs b/src/types.rs index 0714314..d8352be 100644 --- a/src/types.rs +++ b/src/types.rs @@ -71,26 +71,37 @@ pub struct Request<'a> { pub method: &'a str, /// The request parameters pub params: Vec, + + /// Authorization token (e.g. `"Bearer "`) included in the JSON-RPC request, if any. + #[serde(skip_serializing_if = "Option::is_none")] + authorization: Option, } impl<'a> Request<'a> { - /// Creates a new request with a default id + /// Creates a new [`Request`] with a default `id`. fn new(method: &'a str, params: Vec) -> Self { Self { id: 0, jsonrpc: JSONRPC_2_0, method, params, + authorization: None, } } - /// Creates a new request with a user-specified id + /// Creates a new [`Request`] with a user-specified `id`. pub fn new_id(id: usize, method: &'a str, params: Vec) -> Self { let mut instance = Self::new(method, params); instance.id = id; instance } + + /// Sets the `authorization` token for this [`Request`]. + pub fn with_auth(mut self, authorization: Option) -> Self { + self.authorization = authorization; + self + } } #[doc(hidden)] @@ -518,6 +529,8 @@ impl From for Error { mod tests { use crate::ScriptStatus; + use super::{Param, Request}; + #[test] fn script_status_roundtrip() { let script_status: ScriptStatus = [1u8; 32].into(); @@ -525,4 +538,59 @@ mod tests { let script_status_back = serde_json::from_str(&script_status_json).unwrap(); assert_eq!(script_status, script_status_back); } + + #[test] + fn test_request_serialization_without_authorization() { + let req = Request::new_id(1, "server.version", vec![]); + + let json = serde_json::to_string(&req).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + + // Authorization field should not be present when None + assert!(parsed.get("authorization").is_none()); + assert!(!json.contains("authorization")); + assert_eq!(parsed["jsonrpc"], "2.0"); + assert_eq!(parsed["method"], "server.version"); + assert_eq!(parsed["id"], 1); + } + + #[test] + fn test_request_serialization_with_authorization() { + let mut req = Request::new_id(1, "server.version", vec![]); + req.authorization = Some("Bearer test-jwt-token".to_string()); + + let json = serde_json::to_string(&req).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + + // Authorization field should be present + assert_eq!( + parsed["authorization"], + serde_json::Value::String("Bearer test-jwt-token".to_string()) + ); + assert_eq!(parsed["jsonrpc"], "2.0"); + assert_eq!(parsed["method"], "server.version"); + assert_eq!(parsed["id"], 1); + } + + #[test] + fn test_request_with_params_and_authorization() { + let mut req = Request::new_id( + 42, + "blockchain.scripthash.get_balance", + vec![Param::String("test-scripthash".to_string())], + ); + req.authorization = Some("Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9".to_string()); + + let json = serde_json::to_string(&req).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed["id"], 42); + assert_eq!(parsed["method"], "blockchain.scripthash.get_balance"); + assert_eq!( + parsed["authorization"], + "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9" + ); + assert!(parsed["params"].is_array()); + assert_eq!(parsed["params"][0], "test-scripthash"); + } }