From a09ec69f38b9401ba2efec7561ba8a005c039689 Mon Sep 17 00:00:00 2001 From: Borja Castellano Date: Tue, 10 Mar 2026 18:58:12 +0000 Subject: [PATCH] restrict ServiceFlags api --- dash-spv/src/network/handshake.rs | 12 ++-- dash-spv/src/network/peer.rs | 17 ++--- dash/src/network/address.rs | 35 ++++++--- dash/src/network/constants.rs | 113 ++++++++++++------------------ dash/src/network/message.rs | 24 +++---- 5 files changed, 92 insertions(+), 109 deletions(-) diff --git a/dash-spv/src/network/handshake.rs b/dash-spv/src/network/handshake.rs index 215deb2fd..395eab051 100644 --- a/dash-spv/src/network/handshake.rs +++ b/dash-spv/src/network/handshake.rs @@ -4,7 +4,7 @@ use std::net::SocketAddr; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use dashcore::network::constants; -use dashcore::network::constants::{ServiceFlags, NODE_HEADERS_COMPRESSED}; +use dashcore::network::constants::ServiceFlags; use dashcore::network::message::NetworkMessage; use dashcore::network::message_network::VersionMessage; use dashcore::Network; @@ -36,7 +36,7 @@ pub struct HandshakeManager { state: HandshakeState, our_version: u32, peer_version: Option, - peer_services: Option, + peer_services: ServiceFlags, version_received: bool, verack_received: bool, version_sent: bool, @@ -56,7 +56,7 @@ impl HandshakeManager { state: HandshakeState::Init, our_version: constants::PROTOCOL_VERSION, peer_version: None, - peer_services: None, + peer_services: ServiceFlags::NONE, version_received: false, verack_received: false, version_sent: false, @@ -157,7 +157,7 @@ impl HandshakeManager { version_msg ); self.peer_version = Some(version_msg.version); - self.peer_services = Some(version_msg.services); + self.peer_services = version_msg.services; self.version_received = true; // Update connection's peer information @@ -261,7 +261,7 @@ impl HandshakeManager { .as_secs() as i64; // Advertise headers2 support (NODE_HEADERS_COMPRESSED) - let services = ServiceFlags::NONE | NODE_HEADERS_COMPRESSED; + let services = ServiceFlags::NODE_HEADERS_COMPRESSED; // Parse the local address safely let local_addr = "127.0.0.1:0" @@ -313,7 +313,7 @@ impl HandshakeManager { /// Check if peer supports headers2 compression. pub fn peer_supports_headers2(&self) -> bool { - self.peer_services.map(|services| services.has(NODE_HEADERS_COMPRESSED)).unwrap_or(false) + self.peer_services.has(ServiceFlags::NODE_HEADERS_COMPRESSED) } /// Negotiate headers2 support with the peer after handshake completion. diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs index dc3e484ed..1d7f39c3c 100644 --- a/dash-spv/src/network/peer.rs +++ b/dash-spv/src/network/peer.rs @@ -40,7 +40,7 @@ pub struct Peer { pending_pings: HashMap, // nonce -> sent_time // Peer information from Version message version: Option, - services: Option, + services: ServiceFlags, user_agent: Option, best_height: Option, relay: Option, @@ -68,7 +68,7 @@ impl Peer { last_pong_received: None, pending_pings: HashMap::new(), version: None, - services: None, + services: ServiceFlags::NONE, user_agent: None, best_height: None, relay: None, @@ -115,7 +115,7 @@ impl Peer { last_pong_received: None, pending_pings: HashMap::new(), version: None, - services: None, + services: ServiceFlags::NONE, user_agent: None, best_height: None, relay: None, @@ -144,7 +144,7 @@ impl Peer { } pub fn has_service(&self, flags: ServiceFlags) -> bool { - self.services.map(|s| ServiceFlags::from(s).has(flags)).unwrap_or(false) + self.services.has(flags) } /// Connect to the peer (instance method for compatibility). @@ -273,7 +273,7 @@ impl Peer { // All validations passed, update peer info self.version = Some(version_msg.version); - self.services = Some(version_msg.services.as_u64()); + self.services = version_msg.services; self.user_agent = Some(version_msg.user_agent.clone()); self.best_height = Some(version_msg.start_height as u32); self.relay = Some(version_msg.relay); @@ -824,12 +824,7 @@ impl Peer { // We can request headers2 if peer has the service flag for headers2 support // Note: We don't wait for SendHeaders2 from peer as that creates a race condition // during initial sync. The service flag is sufficient to know they support headers2. - if let Some(services) = self.services { - dashcore::network::constants::ServiceFlags::from(services) - .has(dashcore::network::constants::NODE_HEADERS_COMPRESSED) - } else { - false - } + self.services.has(ServiceFlags::NODE_HEADERS_COMPRESSED) } } diff --git a/dash/src/network/address.rs b/dash/src/network/address.rs index 3d7ef89a7..7903e7830 100644 --- a/dash/src/network/address.rs +++ b/dash/src/network/address.rs @@ -308,7 +308,10 @@ impl Encodable for AddrV2Message { fn consensus_encode(&self, w: &mut W) -> Result { let mut len = 0; len += self.time.consensus_encode(w)?; - len += VarInt(self.services.as_u64()).consensus_encode(w)?; + // This msg encodes ServiceFlags as a VarInt, so we need to + // use the specialized method for it. Don't use consensus_encode + // since it encodes as a u64, not a VarInt. + len += self.services.consensus_encode_as_var_int(w)?; len += self.addr.consensus_encode(w)?; w.write_all(&self.port.to_be_bytes())?; @@ -322,7 +325,10 @@ impl Decodable for AddrV2Message { fn consensus_decode(r: &mut R) -> Result { Ok(AddrV2Message { time: Decodable::consensus_decode(r)?, - services: ServiceFlags::from(VarInt::consensus_decode(r)?.0), + // This msg encodes ServiceFlags as a VarInt, so we need to + // use the specialized method for it. Don't use consensus_decode + // since it decodes as a u64, not a VarInt. + services: ServiceFlags::consensus_decode_from_var_int(r)?, addr: Decodable::consensus_decode(r)?, port: u16::swap_bytes(Decodable::consensus_decode(r)?), }) @@ -365,12 +371,13 @@ mod test { #[test] fn debug_format_test() { - let mut flags = ServiceFlags::NETWORK; + let mut services = ServiceFlags::NETWORK; + services.add(ServiceFlags::WITNESS); assert_eq!( format!( "The address is: {:?}", Address { - services: flags.add(ServiceFlags::WITNESS), + services, address: [0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001], port: 8333 } @@ -412,16 +419,20 @@ mod test { #[test] fn test_socket_addr() { + let mut services = ServiceFlags::NETWORK; + services.add(ServiceFlags::WITNESS); + let s4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(111, 222, 123, 4)), 5555); - let a4 = Address::new(&s4, ServiceFlags::NETWORK | ServiceFlags::WITNESS); + let a4 = Address::new(&s4, services); assert_eq!(a4.socket_addr().unwrap(), s4); + let s6 = SocketAddr::new( IpAddr::V6(Ipv6Addr::new( 0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666, 0x7777, 0x8888, )), 9999, ); - let a6 = Address::new(&s6, ServiceFlags::NETWORK | ServiceFlags::WITNESS); + let a6 = Address::new(&s6, services); assert_eq!(a6.socket_addr().unwrap(), s6); } @@ -577,19 +588,23 @@ mod test { let raw = hex!("0261bc6649019902abab208d79627683fd4804010409090909208d"); let addresses: Vec = deserialize(&raw).unwrap(); + let services1 = ServiceFlags::NETWORK; + + let mut services2 = ServiceFlags::NETWORK_LIMITED; + services2.add(ServiceFlags::WITNESS); + services2.add(ServiceFlags::COMPACT_FILTERS); + assert_eq!( addresses, vec![ AddrV2Message { - services: ServiceFlags::NETWORK, + services: services1, time: 0x4966bc61, port: 8333, addr: AddrV2::Unknown(153, hex!("abab")) }, AddrV2Message { - services: ServiceFlags::NETWORK_LIMITED - | ServiceFlags::WITNESS - | ServiceFlags::COMPACT_FILTERS, + services: services2, time: 0x83766279, port: 8333, addr: AddrV2::Ipv4(Ipv4Addr::new(9, 9, 9, 9)) diff --git a/dash/src/network/constants.rs b/dash/src/network/constants.rs index d7ba1fef2..4c79f34cb 100644 --- a/dash/src/network/constants.rs +++ b/dash/src/network/constants.rs @@ -33,16 +33,12 @@ //! assert_eq!(&bytes[..], &[0xBF, 0x0C, 0x6B, 0xBD]); //! ``` -use core::convert::From; -use core::{fmt, ops}; +use core::fmt; use hashes::Hash; use crate::consensus::encode::{self, Decodable, Encodable}; -use crate::{BlockHash, io}; - -// Re-export NODE_HEADERS_COMPRESSED for convenience -pub const NODE_HEADERS_COMPRESSED: ServiceFlags = ServiceFlags::NODE_HEADERS_COMPRESSED; +use crate::{BlockHash, VarInt, io}; /// Version of the protocol as appearing in network message headers /// This constant is used to signal to other peers which features you support. @@ -231,30 +227,44 @@ impl ServiceFlags { // NOTE: When adding new flags, remember to update the Display impl accordingly. /// Add [ServiceFlags] together. - /// - /// Returns itself. - pub fn add(&mut self, other: ServiceFlags) -> ServiceFlags { + pub fn add(&mut self, other: ServiceFlags) { self.0 |= other.0; - *self } /// Remove [ServiceFlags] from this. - /// - /// Returns itself. - pub fn remove(&mut self, other: ServiceFlags) -> ServiceFlags { + pub fn remove(&mut self, other: ServiceFlags) { self.0 ^= other.0; - *self } /// Check whether [ServiceFlags] are included in this one. - pub fn has(self, flags: ServiceFlags) -> bool { + pub fn has(&self, flags: ServiceFlags) -> bool { (self.0 | flags.0) == self.0 } /// Get the integer representation of this [ServiceFlags]. - pub fn as_u64(self) -> u64 { + pub fn as_u64(&self) -> u64 { self.0 } + + // This struct is weird in the dash protocol, sometime services are encoded as u64 + // and sometimes as a VarInt. While the Encodable/Decodable encodes and decodes the u64 + // as usual, this methods use VarInt to satisfy the protocol + + #[inline] + pub fn consensus_encode_as_var_int( + &self, + w: &mut W, + ) -> Result { + self.0.consensus_encode(w) + } + + #[inline] + pub fn consensus_decode_from_var_int( + r: &mut R, + ) -> Result { + let services = VarInt::consensus_decode(r)?; + Ok(ServiceFlags(services.0)) + } } impl fmt::LowerHex for ServiceFlags { @@ -307,47 +317,10 @@ impl fmt::Display for ServiceFlags { } } -impl From for ServiceFlags { - fn from(f: u64) -> Self { - ServiceFlags(f) - } -} - -impl From for u64 { - fn from(val: ServiceFlags) -> Self { - val.0 - } -} - -impl ops::BitOr for ServiceFlags { - type Output = Self; - - fn bitor(mut self, rhs: Self) -> Self { - self.add(rhs) - } -} - -impl ops::BitOrAssign for ServiceFlags { - fn bitor_assign(&mut self, rhs: Self) { - self.add(rhs); - } -} - -impl ops::BitXor for ServiceFlags { - type Output = Self; - - fn bitxor(mut self, rhs: Self) -> Self { - self.remove(rhs) - } -} - -impl ops::BitXorAssign for ServiceFlags { - fn bitxor_assign(&mut self, rhs: Self) { - self.remove(rhs); - } -} - impl Encodable for ServiceFlags { + /// Encodes the service flags as a u64, not a VarInt. Services are usually encoded as a u64 + /// but there are some messages that encode them as a VarInt instead. For those use the + /// specialized method `consensus_encode_as_var_int`. #[inline] fn consensus_encode(&self, w: &mut W) -> Result { self.0.consensus_encode(w) @@ -355,6 +328,9 @@ impl Encodable for ServiceFlags { } impl Decodable for ServiceFlags { + /// Decodes the service flags as a u64, not a VarInt. Services are usually decoded as a u64 + /// but there are some messages that decode them as a VarInt instead. For those use the + /// specialized method `consensus_decode_as_var_int`. #[inline] fn consensus_decode(r: &mut R) -> Result { Ok(ServiceFlags(Decodable::consensus_decode(r)?)) @@ -434,27 +410,28 @@ mod tests { assert!(!flags.has(*f)); } - flags |= ServiceFlags::WITNESS; + flags.add(ServiceFlags::WITNESS); assert_eq!(flags, ServiceFlags::WITNESS); - let mut flags2 = flags | ServiceFlags::GETUTXO; + flags.add(ServiceFlags::GETUTXO); for f in all.iter() { - assert_eq!(flags2.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO); + assert_eq!(flags.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO); } - flags2 ^= ServiceFlags::WITNESS; - assert_eq!(flags2, ServiceFlags::GETUTXO); + flags.remove(ServiceFlags::WITNESS); + assert_eq!(flags, ServiceFlags::GETUTXO); - flags2 |= ServiceFlags::COMPACT_FILTERS; - flags2 ^= ServiceFlags::GETUTXO; - assert_eq!(flags2, ServiceFlags::COMPACT_FILTERS); + flags.add(ServiceFlags::COMPACT_FILTERS); + flags.remove(ServiceFlags::GETUTXO); + assert_eq!(flags, ServiceFlags::COMPACT_FILTERS); // Test formatting. assert_eq!("ServiceFlags(NONE)", ServiceFlags::NONE.to_string()); assert_eq!("ServiceFlags(WITNESS)", ServiceFlags::WITNESS.to_string()); - let flag = ServiceFlags::WITNESS | ServiceFlags::BLOOM | ServiceFlags::NETWORK; - assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flag.to_string()); - let flag = ServiceFlags::WITNESS | 0xf0.into(); - assert_eq!("ServiceFlags(WITNESS|COMPACT_FILTERS|0xb0)", flag.to_string()); + + let mut flags = ServiceFlags::WITNESS; + flags.add(ServiceFlags::BLOOM); + flags.add(ServiceFlags::NETWORK); + assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flags.to_string()); } } diff --git a/dash/src/network/message.rs b/dash/src/network/message.rs index 06f4713bf..44fab4776 100644 --- a/dash/src/network/message.rs +++ b/dash/src/network/message.rs @@ -934,13 +934,11 @@ mod test { assert_eq!(msg.magic, 0xd9b4bef9); if let NetworkMessage::Version(version_msg) = msg.payload { assert_eq!(version_msg.version, 70015); - assert_eq!( - version_msg.services, - ServiceFlags::NETWORK - | ServiceFlags::BLOOM - | ServiceFlags::WITNESS - | ServiceFlags::NETWORK_LIMITED - ); + let mut expected_services = ServiceFlags::NETWORK; + expected_services.add(ServiceFlags::BLOOM); + expected_services.add(ServiceFlags::WITNESS); + expected_services.add(ServiceFlags::NETWORK_LIMITED); + assert_eq!(version_msg.services, expected_services); assert_eq!(version_msg.timestamp, 1548554224); assert_eq!(version_msg.nonce, 13952548347456104954); assert_eq!(version_msg.user_agent, "/Satoshi:0.17.1/"); @@ -979,13 +977,11 @@ mod test { assert_eq!(msg.magic, 0xd9b4bef9); if let NetworkMessage::Version(version_msg) = msg.payload { assert_eq!(version_msg.version, 70015); - assert_eq!( - version_msg.services, - ServiceFlags::NETWORK - | ServiceFlags::BLOOM - | ServiceFlags::WITNESS - | ServiceFlags::NETWORK_LIMITED - ); + let mut expected_services = ServiceFlags::NETWORK; + expected_services.add(ServiceFlags::BLOOM); + expected_services.add(ServiceFlags::WITNESS); + expected_services.add(ServiceFlags::NETWORK_LIMITED); + assert_eq!(version_msg.services, expected_services); assert_eq!(version_msg.timestamp, 1548554224); assert_eq!(version_msg.nonce, 13952548347456104954); assert_eq!(version_msg.user_agent, "/Satoshi:0.17.1/");