Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions dash-spv/src/network/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::net::SocketAddr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use dashcore::network::constants;
use dashcore::network::constants::{ServiceFlags, NODE_HEADERS_COMPRESSED};
use dashcore::network::constants::ServiceFlags;
use dashcore::network::message::NetworkMessage;
use dashcore::network::message_network::VersionMessage;
use dashcore::Network;
Expand Down Expand Up @@ -36,7 +36,7 @@ pub struct HandshakeManager {
state: HandshakeState,
our_version: u32,
peer_version: Option<u32>,
peer_services: Option<ServiceFlags>,
peer_services: ServiceFlags,
version_received: bool,
verack_received: bool,
version_sent: bool,
Expand All @@ -56,7 +56,7 @@ impl HandshakeManager {
state: HandshakeState::Init,
our_version: constants::PROTOCOL_VERSION,
peer_version: None,
peer_services: None,
peer_services: ServiceFlags::NONE,
version_received: false,
verack_received: false,
version_sent: false,
Expand Down Expand Up @@ -157,7 +157,7 @@ impl HandshakeManager {
version_msg
);
self.peer_version = Some(version_msg.version);
self.peer_services = Some(version_msg.services);
self.peer_services = version_msg.services;
self.version_received = true;

// Update connection's peer information
Expand Down Expand Up @@ -261,7 +261,7 @@ impl HandshakeManager {
.as_secs() as i64;

// Advertise headers2 support (NODE_HEADERS_COMPRESSED)
let services = ServiceFlags::NONE | NODE_HEADERS_COMPRESSED;
let services = ServiceFlags::NODE_HEADERS_COMPRESSED;

// Parse the local address safely
let local_addr = "127.0.0.1:0"
Expand Down Expand Up @@ -313,7 +313,7 @@ impl HandshakeManager {

/// Check if peer supports headers2 compression.
pub fn peer_supports_headers2(&self) -> bool {
self.peer_services.map(|services| services.has(NODE_HEADERS_COMPRESSED)).unwrap_or(false)
self.peer_services.has(ServiceFlags::NODE_HEADERS_COMPRESSED)
}

/// Negotiate headers2 support with the peer after handshake completion.
Expand Down
17 changes: 6 additions & 11 deletions dash-spv/src/network/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub struct Peer {
pending_pings: HashMap<u64, SystemTime>, // nonce -> sent_time
// Peer information from Version message
version: Option<u32>,
services: Option<u64>,
services: ServiceFlags,
user_agent: Option<String>,
best_height: Option<u32>,
relay: Option<bool>,
Expand Down Expand Up @@ -68,7 +68,7 @@ impl Peer {
last_pong_received: None,
pending_pings: HashMap::new(),
version: None,
services: None,
services: ServiceFlags::NONE,
user_agent: None,
best_height: None,
relay: None,
Expand Down Expand Up @@ -115,7 +115,7 @@ impl Peer {
last_pong_received: None,
pending_pings: HashMap::new(),
version: None,
services: None,
services: ServiceFlags::NONE,
user_agent: None,
best_height: None,
relay: None,
Expand Down Expand Up @@ -144,7 +144,7 @@ impl Peer {
}

pub fn has_service(&self, flags: ServiceFlags) -> bool {
self.services.map(|s| ServiceFlags::from(s).has(flags)).unwrap_or(false)
self.services.has(flags)
}

/// Connect to the peer (instance method for compatibility).
Expand Down Expand Up @@ -273,7 +273,7 @@ impl Peer {

// All validations passed, update peer info
self.version = Some(version_msg.version);
self.services = Some(version_msg.services.as_u64());
self.services = version_msg.services;
self.user_agent = Some(version_msg.user_agent.clone());
self.best_height = Some(version_msg.start_height as u32);
self.relay = Some(version_msg.relay);
Expand Down Expand Up @@ -824,12 +824,7 @@ impl Peer {
// We can request headers2 if peer has the service flag for headers2 support
// Note: We don't wait for SendHeaders2 from peer as that creates a race condition
// during initial sync. The service flag is sufficient to know they support headers2.
if let Some(services) = self.services {
dashcore::network::constants::ServiceFlags::from(services)
.has(dashcore::network::constants::NODE_HEADERS_COMPRESSED)
} else {
false
}
self.services.has(ServiceFlags::NODE_HEADERS_COMPRESSED)
}
}

Expand Down
35 changes: 25 additions & 10 deletions dash/src/network/address.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,10 @@ impl Encodable for AddrV2Message {
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
let mut len = 0;
len += self.time.consensus_encode(w)?;
len += VarInt(self.services.as_u64()).consensus_encode(w)?;
// This msg encodes ServiceFlags as a VarInt, so we need to
// use the specialized method for it. Don't use consensus_encode
// since it encodes as a u64, not a VarInt.
len += self.services.consensus_encode_as_var_int(w)?;
len += self.addr.consensus_encode(w)?;

w.write_all(&self.port.to_be_bytes())?;
Expand All @@ -322,7 +325,10 @@ impl Decodable for AddrV2Message {
fn consensus_decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, encode::Error> {
Ok(AddrV2Message {
time: Decodable::consensus_decode(r)?,
services: ServiceFlags::from(VarInt::consensus_decode(r)?.0),
// This msg encodes ServiceFlags as a VarInt, so we need to
// use the specialized method for it. Don't use consensus_decode
// since it decodes as a u64, not a VarInt.
services: ServiceFlags::consensus_decode_from_var_int(r)?,
addr: Decodable::consensus_decode(r)?,
port: u16::swap_bytes(Decodable::consensus_decode(r)?),
})
Expand Down Expand Up @@ -365,12 +371,13 @@ mod test {

#[test]
fn debug_format_test() {
let mut flags = ServiceFlags::NETWORK;
let mut services = ServiceFlags::NETWORK;
services.add(ServiceFlags::WITNESS);
assert_eq!(
format!(
"The address is: {:?}",
Address {
services: flags.add(ServiceFlags::WITNESS),
services,
address: [0, 0, 0, 0, 0, 0xffff, 0x0a00, 0x0001],
port: 8333
}
Expand Down Expand Up @@ -412,16 +419,20 @@ mod test {

#[test]
fn test_socket_addr() {
let mut services = ServiceFlags::NETWORK;
services.add(ServiceFlags::WITNESS);

let s4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(111, 222, 123, 4)), 5555);
let a4 = Address::new(&s4, ServiceFlags::NETWORK | ServiceFlags::WITNESS);
let a4 = Address::new(&s4, services);
assert_eq!(a4.socket_addr().unwrap(), s4);

let s6 = SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(
0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666, 0x7777, 0x8888,
)),
9999,
);
let a6 = Address::new(&s6, ServiceFlags::NETWORK | ServiceFlags::WITNESS);
let a6 = Address::new(&s6, services);
assert_eq!(a6.socket_addr().unwrap(), s6);
}

Expand Down Expand Up @@ -577,19 +588,23 @@ mod test {
let raw = hex!("0261bc6649019902abab208d79627683fd4804010409090909208d");
let addresses: Vec<AddrV2Message> = deserialize(&raw).unwrap();

let services1 = ServiceFlags::NETWORK;

let mut services2 = ServiceFlags::NETWORK_LIMITED;
services2.add(ServiceFlags::WITNESS);
services2.add(ServiceFlags::COMPACT_FILTERS);

assert_eq!(
addresses,
vec![
AddrV2Message {
services: ServiceFlags::NETWORK,
services: services1,
time: 0x4966bc61,
port: 8333,
addr: AddrV2::Unknown(153, hex!("abab"))
},
AddrV2Message {
services: ServiceFlags::NETWORK_LIMITED
| ServiceFlags::WITNESS
| ServiceFlags::COMPACT_FILTERS,
services: services2,
time: 0x83766279,
port: 8333,
addr: AddrV2::Ipv4(Ipv4Addr::new(9, 9, 9, 9))
Expand Down
113 changes: 45 additions & 68 deletions dash/src/network/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,12 @@
//! assert_eq!(&bytes[..], &[0xBF, 0x0C, 0x6B, 0xBD]);
//! ```

use core::convert::From;
use core::{fmt, ops};
use core::fmt;

use hashes::Hash;

use crate::consensus::encode::{self, Decodable, Encodable};
use crate::{BlockHash, io};

// Re-export NODE_HEADERS_COMPRESSED for convenience
pub const NODE_HEADERS_COMPRESSED: ServiceFlags = ServiceFlags::NODE_HEADERS_COMPRESSED;
use crate::{BlockHash, VarInt, io};

/// Version of the protocol as appearing in network message headers
/// This constant is used to signal to other peers which features you support.
Expand Down Expand Up @@ -231,30 +227,44 @@ impl ServiceFlags {
// NOTE: When adding new flags, remember to update the Display impl accordingly.

/// Add [ServiceFlags] together.
///
/// Returns itself.
pub fn add(&mut self, other: ServiceFlags) -> ServiceFlags {
pub fn add(&mut self, other: ServiceFlags) {
self.0 |= other.0;
*self
}

/// Remove [ServiceFlags] from this.
///
/// Returns itself.
pub fn remove(&mut self, other: ServiceFlags) -> ServiceFlags {
pub fn remove(&mut self, other: ServiceFlags) {
self.0 ^= other.0;
*self
}

/// Check whether [ServiceFlags] are included in this one.
pub fn has(self, flags: ServiceFlags) -> bool {
pub fn has(&self, flags: ServiceFlags) -> bool {
(self.0 | flags.0) == self.0
}

/// Get the integer representation of this [ServiceFlags].
pub fn as_u64(self) -> u64 {
pub fn as_u64(&self) -> u64 {
self.0
}

// This struct is weird in the dash protocol, sometime services are encoded as u64
// and sometimes as a VarInt. While the Encodable/Decodable encodes and decodes the u64
// as usual, this methods use VarInt to satisfy the protocol

#[inline]
pub fn consensus_encode_as_var_int<W: io::Write + ?Sized>(
&self,
w: &mut W,
) -> Result<usize, io::Error> {
self.0.consensus_encode(w)
}

#[inline]
pub fn consensus_decode_from_var_int<R: io::Read + ?Sized>(
r: &mut R,
) -> Result<Self, encode::Error> {
let services = VarInt::consensus_decode(r)?;
Ok(ServiceFlags(services.0))
}
}

impl fmt::LowerHex for ServiceFlags {
Expand Down Expand Up @@ -307,54 +317,20 @@ impl fmt::Display for ServiceFlags {
}
}

impl From<u64> for ServiceFlags {
fn from(f: u64) -> Self {
ServiceFlags(f)
}
}

impl From<ServiceFlags> for u64 {
fn from(val: ServiceFlags) -> Self {
val.0
}
}

impl ops::BitOr for ServiceFlags {
type Output = Self;

fn bitor(mut self, rhs: Self) -> Self {
self.add(rhs)
}
}

impl ops::BitOrAssign for ServiceFlags {
fn bitor_assign(&mut self, rhs: Self) {
self.add(rhs);
}
}

impl ops::BitXor for ServiceFlags {
type Output = Self;

fn bitxor(mut self, rhs: Self) -> Self {
self.remove(rhs)
}
}

impl ops::BitXorAssign for ServiceFlags {
fn bitxor_assign(&mut self, rhs: Self) {
self.remove(rhs);
}
}

impl Encodable for ServiceFlags {
/// Encodes the service flags as a u64, not a VarInt. Services are usually encoded as a u64
/// but there are some messages that encode them as a VarInt instead. For those use the
/// specialized method `consensus_encode_as_var_int`.
#[inline]
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
self.0.consensus_encode(w)
}
}

impl Decodable for ServiceFlags {
/// Decodes the service flags as a u64, not a VarInt. Services are usually decoded as a u64
/// but there are some messages that decode them as a VarInt instead. For those use the
/// specialized method `consensus_decode_as_var_int`.
#[inline]
fn consensus_decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, encode::Error> {
Ok(ServiceFlags(Decodable::consensus_decode(r)?))
Expand Down Expand Up @@ -434,27 +410,28 @@ mod tests {
assert!(!flags.has(*f));
}

flags |= ServiceFlags::WITNESS;
flags.add(ServiceFlags::WITNESS);
assert_eq!(flags, ServiceFlags::WITNESS);

let mut flags2 = flags | ServiceFlags::GETUTXO;
flags.add(ServiceFlags::GETUTXO);
for f in all.iter() {
assert_eq!(flags2.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO);
assert_eq!(flags.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO);
}

flags2 ^= ServiceFlags::WITNESS;
assert_eq!(flags2, ServiceFlags::GETUTXO);
flags.remove(ServiceFlags::WITNESS);
assert_eq!(flags, ServiceFlags::GETUTXO);

flags2 |= ServiceFlags::COMPACT_FILTERS;
flags2 ^= ServiceFlags::GETUTXO;
assert_eq!(flags2, ServiceFlags::COMPACT_FILTERS);
flags.add(ServiceFlags::COMPACT_FILTERS);
flags.remove(ServiceFlags::GETUTXO);
assert_eq!(flags, ServiceFlags::COMPACT_FILTERS);

// Test formatting.
assert_eq!("ServiceFlags(NONE)", ServiceFlags::NONE.to_string());
assert_eq!("ServiceFlags(WITNESS)", ServiceFlags::WITNESS.to_string());
let flag = ServiceFlags::WITNESS | ServiceFlags::BLOOM | ServiceFlags::NETWORK;
assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flag.to_string());
let flag = ServiceFlags::WITNESS | 0xf0.into();
assert_eq!("ServiceFlags(WITNESS|COMPACT_FILTERS|0xb0)", flag.to_string());

let mut flags = ServiceFlags::WITNESS;
flags.add(ServiceFlags::BLOOM);
flags.add(ServiceFlags::NETWORK);
assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flags.to_string());
}
}
Loading
Loading