From 9efc576d3157bda6b79a7afb2553cd2e5bcd2998 Mon Sep 17 00:00:00 2001
From: Daria Sukhonina
Date: Fri, 6 Feb 2026 18:32:08 +0300
Subject: [PATCH 01/11] Implement TreeNodeIndex

---
 compiler/rustc_data_structures/src/lib.rs | 1 +
 .../src/tree_node_index.rs | 86 +++++++++++++++++++
 2 files changed, 87 insertions(+)
 create mode 100644 compiler/rustc_data_structures/src/tree_node_index.rs

diff --git a/compiler/rustc_data_structures/src/lib.rs b/compiler/rustc_data_structures/src/lib.rs
index b01834aa80d9d..aff625b0fe7a9 100644
--- a/compiler/rustc_data_structures/src/lib.rs
+++ b/compiler/rustc_data_structures/src/lib.rs
@@ -92,6 +92,7 @@ pub mod temp_dir;
 pub mod thinvec;
 pub mod thousands;
 pub mod transitive_relation;
+pub mod tree_node_index;
 pub mod unhash;
 pub mod union_find;
 pub mod unord;
diff --git a/compiler/rustc_data_structures/src/tree_node_index.rs b/compiler/rustc_data_structures/src/tree_node_index.rs
new file mode 100644
index 0000000000000..dfe03f9776cab
--- /dev/null
+++ b/compiler/rustc_data_structures/src/tree_node_index.rs
@@ -0,0 +1,86 @@
+use std::error::Error;
+use std::fmt::Display;
+
+/// Ordered index for dynamic trees
+///
+/// ## Encoding
+///
+/// You can index any node of a binary tree with a variable-length bitstring.
+/// Each bit represents a branch/edge to traverse from the root down to the indexed node.
+/// To encode a variable-length bitstring we use the bits of a u64 ordered from highest to lowest.
+/// Right after the encoded sequence of bits we set one bit high to recover the sequence's length:
+///
+/// ```text
+/// 0bXXXXXXX100000000...0
+/// ```
+///
+/// A node reached after traversal of `LRLRRLLR` branches is represented as `0b0101100110000...0`.
+/// The root is encoded as `0b10000...0`.
+///
+/// ## Order
+///
+/// The encoding allows sorting nodes in left < parent < right linear order.
+/// If you only consider the leaves of a tree then those are sorted in left < right order.
+/// +/// ## Used in +/// +/// Primary purpose of `TreeNodeIndex` is to track order of parallel tasks of functions like `join` +/// or `scope` (see `rustc_middle::sync`). +/// This is done in query cycle handling code to determine **intended** first task for a single- +/// threaded compiler front-end to execute even while multi-threaded. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct TreeNodeIndex(pub u64); + +impl TreeNodeIndex { + pub const fn root() -> Self { + Self(0x80000000_00000000) + } + + /// Append tree branch no. `branch_idx` reserving `bits` bits. + fn try_bits_branch(self, branch_idx: u64, bits: u32) -> Result { + let trailing_zeros = self.0.trailing_zeros(); + let allocated_shift = trailing_zeros.checked_sub(bits).ok_or(BranchingError(()))?; + Ok(TreeNodeIndex( + self.0 & !(1 << trailing_zeros) + | (1 << allocated_shift) + | (branch_idx << (allocated_shift + 1)), + )) + } + + /// Append tree branch no. `branch_idx` reserving `ceil(log2(branch_num))` bits. 
+ pub fn branch(self, branch_idx: u64, branch_num: u64) -> TreeNodeIndex { + debug_assert!( + branch_idx < branch_num, + "branch_idx = {branch_idx} should be less than branch_num = {branch_num}" + ); + // floor(log2(n - 1)) + 1 == ceil(log2(n)) + self.try_bits_branch(branch_idx, (branch_num - 1).checked_ilog2().map_or(0, |b| b + 1)) + .unwrap() + } + + pub fn try_concat(self, then: Self) -> Result { + let trailing_zeros = then.0.trailing_zeros(); + let branch_num = then.0.wrapping_shr(trailing_zeros + 1); + let bits = u64::BITS - trailing_zeros; + self.try_bits_branch(branch_num, bits) + } +} + +/// Error for exhausting free bits +#[derive(Debug)] +pub struct BranchingError(()); + +impl Error for BranchingError {} + +impl Display for BranchingError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + "TreeNodeIndex's free bits have been exhausted, make sure recursion is used carefully" + .fmt(f) + } +} + +impl Default for TreeNodeIndex { + fn default() -> Self { + TreeNodeIndex::root() + } +} From ffa61e077bdb6810e195a787d4e4d83838d71f43 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Fri, 6 Feb 2026 19:18:01 +0300 Subject: [PATCH 02/11] Integrate QueryInclusion into QueryJob --- compiler/rustc_interface/src/interface.rs | 2 +- compiler/rustc_middle/src/query/job.rs | 15 ++++++++---- compiler/rustc_middle/src/query/mod.rs | 2 +- compiler/rustc_middle/src/ty/context/tls.rs | 7 +++--- compiler/rustc_query_impl/src/execution.rs | 17 +++++++------- compiler/rustc_query_impl/src/job.rs | 26 ++++++++++----------- compiler/rustc_query_impl/src/plumbing.rs | 7 +++--- 7 files changed, 42 insertions(+), 34 deletions(-) diff --git a/compiler/rustc_interface/src/interface.rs b/compiler/rustc_interface/src/interface.rs index 91b7f234d5f64..c9f1103ce4466 100644 --- a/compiler/rustc_interface/src/interface.rs +++ b/compiler/rustc_interface/src/interface.rs @@ -546,7 +546,7 @@ pub fn try_print_query_stack( if let Some(icx) = icx { 
ty::print::with_no_queries!(print_query_stack( icx.tcx, - icx.query, + icx.query.map(|i| i.id), dcx, limit_frames, file, diff --git a/compiler/rustc_middle/src/query/job.rs b/compiler/rustc_middle/src/query/job.rs index f1a2b3a34d0e8..e98ca70db5d13 100644 --- a/compiler/rustc_middle/src/query/job.rs +++ b/compiler/rustc_middle/src/query/job.rs @@ -4,6 +4,7 @@ use std::num::NonZero; use std::sync::Arc; use parking_lot::{Condvar, Mutex}; +use rustc_data_structures::tree_node_index::TreeNodeIndex; use rustc_span::Span; use crate::query::plumbing::CycleError; @@ -37,7 +38,7 @@ pub struct QueryJob<'tcx> { pub span: Span, /// The parent query job which created this job and is implicitly waiting on it. - pub parent: Option, + pub parent: Option, /// The latch that is used to wait on this job. pub latch: Option>, @@ -46,7 +47,7 @@ pub struct QueryJob<'tcx> { impl<'tcx> QueryJob<'tcx> { /// Creates a new query job. #[inline] - pub fn new(id: QueryJobId, span: Span, parent: Option) -> Self { + pub fn new(id: QueryJobId, span: Span, parent: Option) -> Self { QueryJob { id, span, parent, latch: None } } @@ -69,9 +70,15 @@ impl<'tcx> QueryJob<'tcx> { } } +#[derive(Clone, Copy, Debug)] +pub struct QueryInclusion { + pub id: QueryJobId, + pub branch: TreeNodeIndex, +} + #[derive(Debug)] pub struct QueryWaiter<'tcx> { - pub query: Option, + pub query: Option, pub condvar: Condvar, pub span: Span, pub cycle: Mutex>>>, @@ -99,7 +106,7 @@ impl<'tcx> QueryLatch<'tcx> { pub fn wait_on( &self, tcx: TyCtxt<'tcx>, - query: Option, + query: Option, span: Span, ) -> Result<(), CycleError>> { let waiter = diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 66e4a77ea6a51..22b1faba7634c 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -3,7 +3,7 @@ use rustc_hir::def_id::LocalDefId; pub use self::caches::{ DefIdCache, DefaultCache, QueryCache, QueryCacheKey, SingleCache, VecCache, }; -pub use 
self::job::{QueryInfo, QueryJob, QueryJobId, QueryLatch, QueryWaiter}; +pub use self::job::{QueryInclusion, QueryInfo, QueryJob, QueryJobId, QueryLatch, QueryWaiter}; pub use self::keys::{AsLocalKey, Key, LocalCrate}; pub use self::plumbing::{ ActiveKeyStatus, CycleError, CycleErrorHandling, IntoQueryParam, QueryMode, QueryState, diff --git a/compiler/rustc_middle/src/ty/context/tls.rs b/compiler/rustc_middle/src/ty/context/tls.rs index d37ad56c2e83d..26b391623cdf7 100644 --- a/compiler/rustc_middle/src/ty/context/tls.rs +++ b/compiler/rustc_middle/src/ty/context/tls.rs @@ -4,7 +4,7 @@ use rustc_data_structures::sync; use super::{GlobalCtxt, TyCtxt}; use crate::dep_graph::TaskDepsRef; -use crate::query::QueryJobId; +use crate::query::QueryInclusion; /// This is the implicit state of rustc. It contains the current /// `TyCtxt` and query. It is updated when creating a local interner or @@ -16,8 +16,9 @@ pub struct ImplicitCtxt<'a, 'tcx> { /// The current `TyCtxt`. pub tcx: TyCtxt<'tcx>, - /// The current query job, if any. - pub query: Option, + /// The current query job, if any. This is updated by `JobOwner::start` in + /// `ty::query::plumbing` when executing a query. + pub query: Option, /// Used to prevent queries from calling too deeply. 
pub query_depth: usize, diff --git a/compiler/rustc_query_impl/src/execution.rs b/compiler/rustc_query_impl/src/execution.rs index 53afcacb63a6c..22992d094afa1 100644 --- a/compiler/rustc_query_impl/src/execution.rs +++ b/compiler/rustc_query_impl/src/execution.rs @@ -8,8 +8,8 @@ use rustc_errors::{Diag, FatalError, StashKey}; use rustc_middle::dep_graph::{DepGraphData, DepNodeKey}; use rustc_middle::query::plumbing::QueryVTable; use rustc_middle::query::{ - ActiveKeyStatus, CycleError, CycleErrorHandling, QueryCache, QueryJob, QueryJobId, QueryLatch, - QueryMode, QueryStackDeferred, QueryStackFrame, QueryState, + ActiveKeyStatus, CycleError, CycleErrorHandling, QueryCache, QueryInclusion, QueryJob, + QueryJobId, QueryLatch, QueryMode, QueryStackDeferred, QueryStackFrame, QueryState, }; use rustc_middle::ty::TyCtxt; use rustc_middle::verify_ich::incremental_verify_ich; @@ -18,7 +18,7 @@ use rustc_span::{DUMMY_SP, Span}; use crate::dep_graph::{DepNode, DepNodeIndex}; use crate::job::{QueryJobInfo, QueryJobMap, find_cycle_in_stack, report_cycle}; use crate::plumbing::{ - collect_active_jobs_from_all_queries, current_query_job, next_job_id, start_query, + collect_active_jobs_from_all_queries, current_query_inclusion, next_job_id, start_query, }; #[inline] @@ -220,7 +220,8 @@ fn cycle_error<'tcx, C: QueryCache>( .ok() .expect("failed to collect active queries"); - let error = find_cycle_in_stack(try_execute, job_map, ¤t_query_job(tcx), span); + let error = + find_cycle_in_stack(try_execute, job_map, current_query_inclusion(tcx).map(|i| i.id), span); (mk_cycle(query, tcx, error.lift()), None) } @@ -231,7 +232,7 @@ fn wait_for_query<'tcx, C: QueryCache>( span: Span, key: C::Key, latch: QueryLatch<'tcx>, - current: Option, + current: Option, ) -> (C::Value, Option) { // For parallel queries, we'll block and wait until the query running // in another thread has completed. 
Record how long we wait in the @@ -295,14 +296,14 @@ fn try_execute_query<'tcx, C: QueryCache, const INCR: bool>( } } - let current_job_id = current_query_job(tcx); + let current_inclusion = current_query_inclusion(tcx); match state_lock.entry(key_hash, equivalent_key(&key), |(k, _)| sharded::make_hash(k)) { Entry::Vacant(entry) => { // Nothing has computed or is computing the query, so we start a new job and insert it in the // state map. let id = next_job_id(tcx); - let job = QueryJob::new(id, span, current_job_id); + let job = QueryJob::new(id, span, current_inclusion); entry.insert((key, ActiveKeyStatus::Started(job))); // Drop the lock before we start executing the query @@ -320,7 +321,7 @@ fn try_execute_query<'tcx, C: QueryCache, const INCR: bool>( // Only call `wait_for_query` if we're using a Rayon thread pool // as it will attempt to mark the worker thread as blocked. - return wait_for_query(query, tcx, span, key, latch, current_job_id); + return wait_for_query(query, tcx, span, key, latch, current_inclusion); } let id = job.id; diff --git a/compiler/rustc_query_impl/src/job.rs b/compiler/rustc_query_impl/src/job.rs index 2d9824a783ea5..0a1e984c560b5 100644 --- a/compiler/rustc_query_impl/src/job.rs +++ b/compiler/rustc_query_impl/src/job.rs @@ -7,8 +7,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_errors::{Diag, DiagCtxtHandle}; use rustc_hir::def::DefKind; use rustc_middle::query::{ - CycleError, QueryInfo, QueryJob, QueryJobId, QueryLatch, QueryStackDeferred, QueryStackFrame, - QueryWaiter, + CycleError, QueryInclusion, QueryInfo, QueryJob, QueryJobId, QueryLatch, QueryStackDeferred, QueryStackFrame, QueryWaiter }; use rustc_middle::ty::TyCtxt; use rustc_session::Session; @@ -39,7 +38,7 @@ impl<'tcx> QueryJobMap<'tcx> { self.map[&id].job.span } - fn parent_of(&self, id: QueryJobId) -> Option { + fn parent_of(&self, id: QueryJobId) -> Option { self.map[&id].job.parent } @@ -57,12 +56,11 @@ pub(crate) struct QueryJobInfo<'tcx> { 
pub(crate) fn find_cycle_in_stack<'tcx>( id: QueryJobId, job_map: QueryJobMap<'tcx>, - current_job: &Option, + mut current_job: Option, span: Span, ) -> CycleError> { // Find the waitee amongst `current_job` parents let mut cycle = Vec::new(); - let mut current_job = Option::clone(current_job); while let Some(job) = current_job { let info = &job_map.map[&job]; @@ -79,12 +77,12 @@ pub(crate) fn find_cycle_in_stack<'tcx>( // Find out why the cycle itself was used let usage = try { let parent = info.job.parent?; - (info.job.span, job_map.frame_of(parent).clone()) + (info.job.span, job_map.frame_of(parent.id).clone()) }; return CycleError { usage, cycle }; } - current_job = info.job.parent; + current_job = info.job.parent.map(|i| i.id); } panic!("did not find a cycle") @@ -99,16 +97,16 @@ pub(crate) fn find_dep_kind_root<'tcx>( let mut depth = 1; let info = &job_map.map[&id]; let dep_kind = info.frame.dep_kind; - let mut current_id = info.job.parent; + let mut current = info.job.parent; let mut last_layout = (info.clone(), depth); - while let Some(id) = current_id { - let info = &job_map.map[&id]; + while let Some(inclusion) = current { + let info = &job_map.map[&inclusion.id]; if info.frame.dep_kind == dep_kind { depth += 1; last_layout = (info.clone(), depth); } - current_id = info.job.parent; + current = info.job.parent; } last_layout } @@ -131,7 +129,7 @@ fn visit_waiters<'tcx>( ) -> ControlFlow> { // Visit the parent query which is a non-resumable waiter since it's on the same stack if let Some(parent) = job_map.parent_of(query) { - visit(job_map.span_of(query), parent)?; + visit(job_map.span_of(query), parent.id)?; } // Visit the explicit waiters which use condvars and are resumable @@ -139,7 +137,7 @@ fn visit_waiters<'tcx>( for (i, waiter) in latch.info.lock().waiters.iter().enumerate() { if let Some(waiter_query) = waiter.query { // Return a value which indicates that this waiter can be resumed - visit(waiter.span, waiter_query).map_break(|_| Some((query, 
i)))?; + visit(waiter.span, waiter_query.id).map_break(|_| Some((query, i)))?; } } } @@ -415,7 +413,7 @@ pub fn print_query_stack<'tcx>( ); } - current_query = query_info.job.parent; + current_query = query_info.job.parent.map(|i| i.id); count_total += 1; } diff --git a/compiler/rustc_query_impl/src/plumbing.rs b/compiler/rustc_query_impl/src/plumbing.rs index 11077e8e0ee20..4d9ac6fc0069a 100644 --- a/compiler/rustc_query_impl/src/plumbing.rs +++ b/compiler/rustc_query_impl/src/plumbing.rs @@ -5,6 +5,7 @@ use std::num::NonZero; use rustc_data_structures::sync::{DynSend, DynSync}; +use rustc_data_structures::tree_node_index::TreeNodeIndex; use rustc_data_structures::unord::UnordMap; use rustc_hir::def_id::DefId; use rustc_hir::limit::Limit; @@ -18,7 +19,7 @@ use rustc_middle::query::on_disk_cache::{ }; use rustc_middle::query::plumbing::QueryVTable; use rustc_middle::query::{ - Key, QueryCache, QueryJobId, QueryStackDeferred, QueryStackFrame, QueryStackFrameExtra, + Key, QueryCache, QueryInclusion, QueryJobId, QueryStackDeferred, QueryStackFrame, QueryStackFrameExtra }; use rustc_middle::ty::codec::TyEncoder; use rustc_middle::ty::print::with_reduced_queries; @@ -59,7 +60,7 @@ pub(crate) fn next_job_id<'tcx>(tcx: TyCtxt<'tcx>) -> QueryJobId { } #[inline] -pub(crate) fn current_query_job<'tcx>(tcx: TyCtxt<'tcx>) -> Option { +pub(crate) fn current_query_inclusion<'tcx>(tcx: TyCtxt<'tcx>) -> Option { tls::with_related_context(tcx, |icx| icx.query) } @@ -83,7 +84,7 @@ pub(crate) fn start_query<'tcx, R>( // Update the `ImplicitCtxt` to point to our new query job. 
let new_icx = ImplicitCtxt { tcx, - query: Some(token), + query: Some(QueryInclusion { id: token, branch: TreeNodeIndex::root() }), query_depth: current_icx.query_depth + depth_limit as usize, task_deps: current_icx.task_deps, }; From 7b8a30b395cbbf0fb7de98df064a6b7151cd6158 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Fri, 6 Feb 2026 18:22:06 +0300 Subject: [PATCH 03/11] Move par_fns, par_join, par_slice and others to rustc_middle --- .../rustc_codegen_cranelift/src/driver/aot.rs | 2 +- compiler/rustc_codegen_ssa/src/base.rs | 2 +- compiler/rustc_data_structures/src/sync.rs | 3 +- .../src/sync/parallel.rs | 159 ----------------- .../rustc_incremental/src/persist/save.rs | 2 +- compiler/rustc_interface/src/passes.rs | 3 +- compiler/rustc_lint/src/late.rs | 2 +- compiler/rustc_metadata/src/rmeta/encoder.rs | 2 +- compiler/rustc_middle/src/hir/map.rs | 3 +- compiler/rustc_middle/src/hir/mod.rs | 2 +- compiler/rustc_middle/src/lib.rs | 1 + compiler/rustc_middle/src/sync.rs | 164 ++++++++++++++++++ compiler/rustc_monomorphize/src/collector.rs | 2 +- .../rustc_monomorphize/src/partitioning.rs | 2 +- src/tools/miri/src/bin/miri.rs | 3 +- 15 files changed, 179 insertions(+), 173 deletions(-) create mode 100644 compiler/rustc_middle/src/sync.rs diff --git a/compiler/rustc_codegen_cranelift/src/driver/aot.rs b/compiler/rustc_codegen_cranelift/src/driver/aot.rs index fc5c634d95709..a688410a935fd 100644 --- a/compiler/rustc_codegen_cranelift/src/driver/aot.rs +++ b/compiler/rustc_codegen_cranelift/src/driver/aot.rs @@ -15,11 +15,11 @@ use rustc_codegen_ssa::base::determine_cgu_reuse; use rustc_codegen_ssa::{CodegenResults, CompiledModule, CrateInfo, ModuleKind}; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; -use rustc_data_structures::sync::{IntoDynSyncSend, par_map}; use rustc_hir::attrs::Linkage as RLinkage; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use 
rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; use rustc_middle::mir::mono::{CodegenUnit, MonoItem, MonoItemData, Visibility}; +use rustc_middle::sync::{IntoDynSyncSend, par_map}; use rustc_session::Session; use rustc_session::config::{OutputFilenames, OutputType}; diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index 3939f145df881..51c768f17f4d9 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -11,7 +11,6 @@ use rustc_ast::expand::allocator::{ }; use rustc_data_structures::fx::{FxHashMap, FxIndexSet}; use rustc_data_structures::profiling::{get_resident_set_size, print_time_passes_entry}; -use rustc_data_structures::sync::{IntoDynSyncSend, par_map}; use rustc_data_structures::unord::UnordMap; use rustc_hir::attrs::{DebuggerVisualizerType, OptimizeAttr}; use rustc_hir::def_id::{DefId, LOCAL_CRATE}; @@ -26,6 +25,7 @@ use rustc_middle::mir::BinOp; use rustc_middle::mir::interpret::ErrorHandled; use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions}; use rustc_middle::query::Providers; +use rustc_middle::sync::{IntoDynSyncSend, par_map}; use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout}; use rustc_middle::ty::{self, Instance, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; diff --git a/compiler/rustc_data_structures/src/sync.rs b/compiler/rustc_data_structures/src/sync.rs index 31768fe189aef..336b8fd37aee3 100644 --- a/compiler/rustc_data_structures/src/sync.rs +++ b/compiler/rustc_data_structures/src/sync.rs @@ -41,8 +41,7 @@ pub use self::freeze::{FreezeLock, FreezeReadGuard, FreezeWriteGuard}; pub use self::lock::{Lock, LockGuard, Mode}; pub use self::mode::{is_dyn_thread_safe, set_dyn_thread_safe_mode}; pub use self::parallel::{ - broadcast, par_fns, par_for_each_in, par_join, par_map, parallel_guard, spawn, - try_par_for_each_in, + ParallelGuard, broadcast, parallel_guard, spawn, }; pub 
use self::vec::{AppendOnlyIndexVec, AppendOnlyVec}; pub use self::worker_local::{Registry, WorkerLocal}; diff --git a/compiler/rustc_data_structures/src/sync/parallel.rs b/compiler/rustc_data_structures/src/sync/parallel.rs index 2ab4a7f75b6bd..15039567f6859 100644 --- a/compiler/rustc_data_structures/src/sync/parallel.rs +++ b/compiler/rustc_data_structures/src/sync/parallel.rs @@ -43,19 +43,6 @@ pub fn parallel_guard(f: impl FnOnce(&ParallelGuard) -> R) -> R { ret } -fn serial_join(oper_a: A, oper_b: B) -> (RA, RB) -where - A: FnOnce() -> RA, - B: FnOnce() -> RB, -{ - let (a, b) = parallel_guard(|guard| { - let a = guard.run(oper_a); - let b = guard.run(oper_b); - (a, b) - }); - (a.unwrap(), b.unwrap()) -} - pub fn spawn(func: impl FnOnce() + DynSend + 'static) { if mode::is_dyn_thread_safe() { let func = FromDyn::from(func); @@ -67,152 +54,6 @@ pub fn spawn(func: impl FnOnce() + DynSend + 'static) { } } -/// Runs the functions in parallel. -/// -/// The first function is executed immediately on the current thread. -/// Use that for the longest running function for better scheduling. -pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) { - parallel_guard(|guard: &ParallelGuard| { - if mode::is_dyn_thread_safe() { - let funcs = FromDyn::from(funcs); - rustc_thread_pool::scope(|s| { - let Some((first, rest)) = funcs.into_inner().split_at_mut_checked(1) else { - return; - }; - - // Reverse the order of the later functions since Rayon executes them in reverse - // order when using a single thread. This ensures the execution order matches - // that of a single threaded rustc. - for f in rest.iter_mut().rev() { - let f = FromDyn::from(f); - s.spawn(|_| { - guard.run(|| (f.into_inner())()); - }); - } - - // Run the first function without spawning to - // ensure it executes immediately on this thread. 
- guard.run(|| first[0]()); - }); - } else { - for f in funcs { - guard.run(|| f()); - } - } - }); -} - -#[inline] -pub fn par_join(oper_a: A, oper_b: B) -> (RA, RB) -where - A: FnOnce() -> RA + DynSend, - B: FnOnce() -> RB + DynSend, -{ - if mode::is_dyn_thread_safe() { - let oper_a = FromDyn::from(oper_a); - let oper_b = FromDyn::from(oper_b); - let (a, b) = parallel_guard(|guard| { - rustc_thread_pool::join( - move || guard.run(move || FromDyn::from(oper_a.into_inner()())), - move || guard.run(move || FromDyn::from(oper_b.into_inner()())), - ) - }); - (a.unwrap().into_inner(), b.unwrap().into_inner()) - } else { - serial_join(oper_a, oper_b) - } -} - -fn par_slice( - items: &mut [I], - guard: &ParallelGuard, - for_each: impl Fn(&mut I) + DynSync + DynSend, -) { - let for_each = FromDyn::from(for_each); - let mut items = for_each.derive(items); - rustc_thread_pool::scope(|s| { - let proof = items.derive(()); - let group_size = std::cmp::max(items.len() / 128, 1); - for group in items.chunks_mut(group_size) { - let group = proof.derive(group); - s.spawn(|_| { - let mut group = group; - for i in group.iter_mut() { - guard.run(|| for_each(i)); - } - }); - } - }); -} - -pub fn par_for_each_in>( - t: T, - for_each: impl Fn(&I) + DynSync + DynSend, -) { - parallel_guard(|guard| { - if mode::is_dyn_thread_safe() { - let mut items: Vec<_> = t.into_iter().collect(); - par_slice(&mut items, guard, |i| for_each(&*i)) - } else { - t.into_iter().for_each(|i| { - guard.run(|| for_each(&i)); - }); - } - }); -} - -/// This runs `for_each` in parallel for each iterator item. If one or more of the -/// `for_each` calls returns `Err`, the function will also return `Err`. The error returned -/// will be non-deterministic, but this is expected to be used with `ErrorGuaranteed` which -/// are all equivalent. 
-pub fn try_par_for_each_in( - t: T, - for_each: impl Fn(&::Item) -> Result<(), E> + DynSync + DynSend, -) -> Result<(), E> -where - ::Item: DynSend, -{ - parallel_guard(|guard| { - if mode::is_dyn_thread_safe() { - let mut items: Vec<_> = t.into_iter().collect(); - - let error = Mutex::new(None); - - par_slice(&mut items, guard, |i| { - if let Err(err) = for_each(&*i) { - *error.lock() = Some(err); - } - }); - - if let Some(err) = error.into_inner() { Err(err) } else { Ok(()) } - } else { - t.into_iter().filter_map(|i| guard.run(|| for_each(&i))).fold(Ok(()), Result::and) - } - }) -} - -pub fn par_map, R: DynSend, C: FromIterator>( - t: T, - map: impl Fn(I) -> R + DynSync + DynSend, -) -> C { - parallel_guard(|guard| { - if mode::is_dyn_thread_safe() { - let map = FromDyn::from(map); - - let mut items: Vec<(Option, Option)> = - t.into_iter().map(|i| (Some(i), None)).collect(); - - par_slice(&mut items, guard, |i| { - i.1 = Some(map(i.0.take().unwrap())); - }); - - items.into_iter().filter_map(|i| i.1).collect() - } else { - t.into_iter().filter_map(|i| guard.run(|| map(i))).collect() - } - }) -} - pub fn broadcast(op: impl Fn(usize) -> R + DynSync) -> Vec { if mode::is_dyn_thread_safe() { let op = FromDyn::from(op); diff --git a/compiler/rustc_incremental/src/persist/save.rs b/compiler/rustc_incremental/src/persist/save.rs index 996ae162607d3..36ccd62dbfd27 100644 --- a/compiler/rustc_incremental/src/persist/save.rs +++ b/compiler/rustc_incremental/src/persist/save.rs @@ -2,10 +2,10 @@ use std::fs; use std::sync::Arc; use rustc_data_structures::fx::FxIndexMap; -use rustc_data_structures::sync::par_join; use rustc_middle::dep_graph::{ DepGraph, SerializedDepGraph, WorkProduct, WorkProductId, WorkProductMap, }; +use rustc_middle::sync::par_join; use rustc_middle::ty::TyCtxt; use rustc_serialize::Encodable as RustcEncodable; use rustc_serialize::opaque::{FileEncodeResult, FileEncoder}; diff --git a/compiler/rustc_interface/src/passes.rs 
b/compiler/rustc_interface/src/passes.rs index 15addd2407857..9c58926b4aad2 100644 --- a/compiler/rustc_interface/src/passes.rs +++ b/compiler/rustc_interface/src/passes.rs @@ -11,7 +11,7 @@ use rustc_codegen_ssa::traits::CodegenBackend; use rustc_codegen_ssa::{CodegenResults, CrateInfo}; use rustc_data_structures::indexmap::IndexMap; use rustc_data_structures::steal::Steal; -use rustc_data_structures::sync::{AppendOnlyIndexVec, FreezeLock, WorkerLocal, par_fns}; +use rustc_data_structures::sync::{AppendOnlyIndexVec, FreezeLock, WorkerLocal}; use rustc_data_structures::thousands; use rustc_errors::timings::TimingSection; use rustc_expand::base::{ExtCtxt, LintStoreExpand}; @@ -27,6 +27,7 @@ use rustc_lint::{BufferedEarlyLint, EarlyCheckNode, LintStore, unerased_lint_sto use rustc_metadata::EncodedMetadata; use rustc_metadata::creader::CStore; use rustc_middle::arena::Arena; +use rustc_middle::sync::par_fns; use rustc_middle::ty::{self, RegisteredTools, TyCtxt}; use rustc_middle::util::Providers; use rustc_parse::lexer::StripTokens; diff --git a/compiler/rustc_lint/src/late.rs b/compiler/rustc_lint/src/late.rs index 3cc0d46d8541f..579b64ef88975 100644 --- a/compiler/rustc_lint/src/late.rs +++ b/compiler/rustc_lint/src/late.rs @@ -7,10 +7,10 @@ use std::any::Any; use std::cell::Cell; use rustc_data_structures::stack::ensure_sufficient_stack; -use rustc_data_structures::sync::par_join; use rustc_hir::def_id::{LocalDefId, LocalModDefId}; use rustc_hir::{self as hir, AmbigArg, HirId, intravisit as hir_visit}; use rustc_middle::hir::nested_filter; +use rustc_middle::sync::par_join; use rustc_middle::ty::{self, TyCtxt}; use rustc_session::Session; use rustc_session::lint::LintPass; diff --git a/compiler/rustc_metadata/src/rmeta/encoder.rs b/compiler/rustc_metadata/src/rmeta/encoder.rs index 9b0a114fecea8..46924dd428eb7 100644 --- a/compiler/rustc_metadata/src/rmeta/encoder.rs +++ b/compiler/rustc_metadata/src/rmeta/encoder.rs @@ -7,7 +7,6 @@ use std::sync::Arc; use 
rustc_data_structures::fx::{FxIndexMap, FxIndexSet}; use rustc_data_structures::memmap::{Mmap, MmapMut}; -use rustc_data_structures::sync::{par_for_each_in, par_join}; use rustc_data_structures::temp_dir::MaybeTempDir; use rustc_data_structures::thousands::usize_with_underscores; use rustc_feature::Features; @@ -21,6 +20,7 @@ use rustc_middle::dep_graph::WorkProductId; use rustc_middle::middle::dependency_format::Linkage; use rustc_middle::mir::interpret; use rustc_middle::query::Providers; +use rustc_middle::sync::{par_for_each_in, par_join}; use rustc_middle::traits::specialization_graph; use rustc_middle::ty::AssocContainer; use rustc_middle::ty::codec::TyEncoder; diff --git a/compiler/rustc_middle/src/hir/map.rs b/compiler/rustc_middle/src/hir/map.rs index 67dd26c8a7d31..a4ca52acc0d1c 100644 --- a/compiler/rustc_middle/src/hir/map.rs +++ b/compiler/rustc_middle/src/hir/map.rs @@ -7,13 +7,14 @@ use rustc_ast::visit::{VisitorResult, walk_list}; use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; use rustc_data_structures::svh::Svh; -use rustc_data_structures::sync::{DynSend, DynSync, par_for_each_in, try_par_for_each_in}; +use rustc_data_structures::sync::{DynSend, DynSync}; use rustc_hir::def::{DefKind, Res}; use rustc_hir::def_id::{DefId, LOCAL_CRATE, LocalDefId, LocalModDefId}; use rustc_hir::definitions::{DefKey, DefPath, DefPathHash}; use rustc_hir::intravisit::Visitor; use rustc_hir::*; use rustc_hir_pretty as pprust_hir; +use rustc_middle::sync::{par_for_each_in, try_par_for_each_in}; use rustc_span::def_id::StableCrateId; use rustc_span::{ErrorGuaranteed, Ident, Span, Symbol, kw, with_metavar_spans}; diff --git a/compiler/rustc_middle/src/hir/mod.rs b/compiler/rustc_middle/src/hir/mod.rs index 82f8eb4bbc4a1..8f17b44951533 100644 --- a/compiler/rustc_middle/src/hir/mod.rs +++ b/compiler/rustc_middle/src/hir/mod.rs @@ -9,7 +9,6 @@ pub mod place; use 
rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::sorted_map::SortedMap; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; -use rustc_data_structures::sync::{DynSend, DynSync, try_par_for_each_in}; use rustc_hir::def::{DefKind, Res}; use rustc_hir::def_id::{DefId, LocalDefId, LocalModDefId}; use rustc_hir::lints::DelayedLint; @@ -18,6 +17,7 @@ use rustc_macros::{Decodable, Encodable, HashStable}; use rustc_span::{ErrorGuaranteed, ExpnId, Span}; use crate::query::Providers; +use crate::sync::{DynSend, DynSync, try_par_for_each_in}; use crate::ty::TyCtxt; /// Gather the LocalDefId for each item-like within a module, including items contained within diff --git a/compiler/rustc_middle/src/lib.rs b/compiler/rustc_middle/src/lib.rs index 615381b37cdb1..5f4a49b60bed6 100644 --- a/compiler/rustc_middle/src/lib.rs +++ b/compiler/rustc_middle/src/lib.rs @@ -78,6 +78,7 @@ pub mod lint; pub mod metadata; pub mod middle; pub mod mir; +pub mod sync; pub mod thir; pub mod traits; pub mod ty; diff --git a/compiler/rustc_middle/src/sync.rs b/compiler/rustc_middle/src/sync.rs new file mode 100644 index 0000000000000..9de7a3c5b13fe --- /dev/null +++ b/compiler/rustc_middle/src/sync.rs @@ -0,0 +1,164 @@ +use parking_lot::Mutex; +pub use rustc_data_structures::marker::{DynSend, DynSync}; +pub use rustc_data_structures::sync::*; + +pub use crate::ty::tls; + +fn serial_join(oper_a: A, oper_b: B) -> (RA, RB) +where + A: FnOnce() -> RA, + B: FnOnce() -> RB, +{ + let (a, b) = parallel_guard(|guard| { + let a = guard.run(oper_a); + let b = guard.run(oper_b); + (a, b) + }); + (a.unwrap(), b.unwrap()) +} + +/// Runs the functions in parallel. +/// +/// The first function is executed immediately on the current thread. +/// Use that for the longest running function for better scheduling. 
+pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) { + parallel_guard(|guard: &ParallelGuard| { + if is_dyn_thread_safe() { + let funcs = FromDyn::from(funcs); + rustc_thread_pool::scope(|s| { + let Some((first, rest)) = funcs.into_inner().split_at_mut_checked(1) else { + return; + }; + + // Reverse the order of the later functions since Rayon executes them in reverse + // order when using a single thread. This ensures the execution order matches + // that of a single threaded rustc. + for f in rest.iter_mut().rev() { + let f = FromDyn::from(f); + s.spawn(|_| { + guard.run(|| (f.into_inner())()); + }); + } + + // Run the first function without spawning to + // ensure it executes immediately on this thread. + guard.run(|| first[0]()); + }); + } else { + for f in funcs { + guard.run(|| f()); + } + } + }); +} + +#[inline] +pub fn par_join(oper_a: A, oper_b: B) -> (RA, RB) +where + A: FnOnce() -> RA + DynSend, + B: FnOnce() -> RB + DynSend, +{ + if is_dyn_thread_safe() { + let oper_a = FromDyn::from(oper_a); + let oper_b = FromDyn::from(oper_b); + let (a, b) = parallel_guard(|guard| { + rustc_thread_pool::join( + move || guard.run(move || FromDyn::from(oper_a.into_inner()())), + move || guard.run(move || FromDyn::from(oper_b.into_inner()())), + ) + }); + (a.unwrap().into_inner(), b.unwrap().into_inner()) + } else { + serial_join(oper_a, oper_b) + } +} + +fn par_slice( + items: &mut [I], + guard: &ParallelGuard, + for_each: impl Fn(&mut I) + DynSync + DynSend, +) { + let for_each = FromDyn::from(for_each); + let mut items = for_each.derive(items); + rustc_thread_pool::scope(|s| { + let proof = items.derive(()); + let group_size = std::cmp::max(items.len() / 128, 1); + for group in items.chunks_mut(group_size) { + let group = proof.derive(group); + s.spawn(|_| { + let mut group = group; + for i in group.iter_mut() { + guard.run(|| for_each(i)); + } + }); + } + }); +} + +pub fn par_for_each_in>( + t: T, + for_each: impl Fn(&I) + DynSync + DynSend, +) { + 
parallel_guard(|guard| { + if is_dyn_thread_safe() { + let mut items: Vec<_> = t.into_iter().collect(); + par_slice(&mut items, guard, |i| for_each(&*i)) + } else { + t.into_iter().for_each(|i| { + guard.run(|| for_each(&i)); + }); + } + }); +} + +/// This runs `for_each` in parallel for each iterator item. If one or more of the +/// `for_each` calls returns `Err`, the function will also return `Err`. The error returned +/// will be non-deterministic, but this is expected to be used with `ErrorGuaranteed` which +/// are all equivalent. +pub fn try_par_for_each_in( + t: T, + for_each: impl Fn(&::Item) -> Result<(), E> + DynSync + DynSend, +) -> Result<(), E> +where + ::Item: DynSend, +{ + parallel_guard(|guard| { + if is_dyn_thread_safe() { + let mut items: Vec<_> = t.into_iter().collect(); + + let error = Mutex::new(None); + + par_slice(&mut items, guard, |i| { + if let Err(err) = for_each(&*i) { + *error.lock() = Some(err); + } + }); + + if let Some(err) = error.into_inner() { Err(err) } else { Ok(()) } + } else { + t.into_iter().filter_map(|i| guard.run(|| for_each(&i))).fold(Ok(()), Result::and) + } + }) +} + +pub fn par_map, R: DynSend, C: FromIterator>( + t: T, + map: impl Fn(I) -> R + DynSync + DynSend, +) -> C { + parallel_guard(|guard| { + if is_dyn_thread_safe() { + let map = FromDyn::from(map); + + let mut items: Vec<(Option, Option)> = + t.into_iter().map(|i| (Some(i), None)).collect(); + + par_slice(&mut items, guard, |i| { + i.1 = Some(map(i.0.take().unwrap())); + }); + + items.into_iter().filter_map(|i| i.1).collect() + } else { + t.into_iter().filter_map(|i| guard.run(|| map(i))).collect() + } + }) +} diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 4f6e2cc005160..e5b33ba64b560 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -211,7 +211,6 @@ use std::cell::OnceCell; use std::ops::ControlFlow; use 
rustc_data_structures::fx::FxIndexMap; -use rustc_data_structures::sync::{MTLock, par_for_each_in}; use rustc_data_structures::unord::{UnordMap, UnordSet}; use rustc_hir as hir; use rustc_hir::attrs::InlineAttr; @@ -227,6 +226,7 @@ use rustc_middle::mir::mono::{ use rustc_middle::mir::visit::Visitor as MirVisitor; use rustc_middle::mir::{self, Body, Location, MentionedItem, traversal}; use rustc_middle::query::TyCtxtAt; +use rustc_middle::sync::{MTLock, par_for_each_in}; use rustc_middle::ty::adjustment::{CustomCoerceUnsized, PointerCoercion}; use rustc_middle::ty::layout::ValidityRequirement; use rustc_middle::ty::{ diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index d8f4e01945075..dc72e74be2305 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -99,7 +99,6 @@ use std::io::Write; use std::path::{Path, PathBuf}; use rustc_data_structures::fx::{FxIndexMap, FxIndexSet}; -use rustc_data_structures::sync::par_join; use rustc_data_structures::unord::{UnordMap, UnordSet}; use rustc_hir::LangItem; use rustc_hir::attrs::{InlineAttr, Linkage}; @@ -113,6 +112,7 @@ use rustc_middle::mir::mono::{ CodegenUnit, CodegenUnitNameBuilder, InstantiationMode, MonoItem, MonoItemData, MonoItemPartitions, Visibility, }; +use rustc_middle::sync::par_join; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; use rustc_middle::ty::{self, InstanceKind, TyCtxt}; use rustc_middle::util::Providers; diff --git a/src/tools/miri/src/bin/miri.rs b/src/tools/miri/src/bin/miri.rs index 14528759472c8..4db06f22e2143 100644 --- a/src/tools/miri/src/bin/miri.rs +++ b/src/tools/miri/src/bin/miri.rs @@ -9,7 +9,6 @@ // The rustc crates we need extern crate rustc_abi; extern crate rustc_codegen_ssa; -extern crate rustc_data_structures; extern crate rustc_driver; extern crate rustc_hir; extern crate rustc_hir_analysis; @@ -51,7 +50,6 @@ use miri::{ 
}; use rustc_abi::ExternAbi; use rustc_codegen_ssa::traits::CodegenBackend; -use rustc_data_structures::sync::{self, DynSync}; use rustc_driver::Compilation; use rustc_hir::def_id::LOCAL_CRATE; use rustc_hir::{self as hir, Node}; @@ -64,6 +62,7 @@ use rustc_middle::middle::exported_symbols::{ ExportedSymbol, SymbolExportInfo, SymbolExportKind, SymbolExportLevel, }; use rustc_middle::query::LocalCrate; +use rustc_middle::sync::{self, DynSync}; use rustc_middle::traits::{ObligationCause, ObligationCauseCode}; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_session::EarlyDiagCtxt; From 1d02e902c20a0612847b77b713105df4eb5a7064 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Fri, 6 Feb 2026 19:24:10 +0300 Subject: [PATCH 04/11] Implement TreeNodeIndex tracking for parallel interfaces --- compiler/rustc_data_structures/src/sync.rs | 4 +- compiler/rustc_middle/src/sync.rs | 85 +++++++++++++++++++--- compiler/rustc_query_impl/src/job.rs | 3 +- compiler/rustc_query_impl/src/plumbing.rs | 3 +- 4 files changed, 78 insertions(+), 17 deletions(-) diff --git a/compiler/rustc_data_structures/src/sync.rs b/compiler/rustc_data_structures/src/sync.rs index 336b8fd37aee3..fa3f61fe05ba3 100644 --- a/compiler/rustc_data_structures/src/sync.rs +++ b/compiler/rustc_data_structures/src/sync.rs @@ -40,9 +40,7 @@ pub use self::freeze::{FreezeLock, FreezeReadGuard, FreezeWriteGuard}; #[doc(no_inline)] pub use self::lock::{Lock, LockGuard, Mode}; pub use self::mode::{is_dyn_thread_safe, set_dyn_thread_safe_mode}; -pub use self::parallel::{ - ParallelGuard, broadcast, parallel_guard, spawn, -}; +pub use self::parallel::{ParallelGuard, broadcast, parallel_guard, spawn}; pub use self::vec::{AppendOnlyIndexVec, AppendOnlyVec}; pub use self::worker_local::{Registry, WorkerLocal}; pub use crate::marker::*; diff --git a/compiler/rustc_middle/src/sync.rs b/compiler/rustc_middle/src/sync.rs index 9de7a3c5b13fe..63929a8d87b41 100644 --- a/compiler/rustc_middle/src/sync.rs +++ 
b/compiler/rustc_middle/src/sync.rs @@ -2,6 +2,7 @@ use parking_lot::Mutex; pub use rustc_data_structures::marker::{DynSend, DynSync}; pub use rustc_data_structures::sync::*; +use crate::query::QueryInclusion; pub use crate::ty::tls; fn serial_join(oper_a: A, oper_b: B) -> (RA, RB) @@ -24,6 +25,7 @@ where pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) { parallel_guard(|guard: &ParallelGuard| { if is_dyn_thread_safe() { + let func_count = funcs.len().try_into().unwrap(); let funcs = FromDyn::from(funcs); rustc_thread_pool::scope(|s| { let Some((first, rest)) = funcs.into_inner().split_at_mut_checked(1) else { @@ -33,16 +35,18 @@ pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) { // Reverse the order of the later functions since Rayon executes them in reverse // order when using a single thread. This ensures the execution order matches // that of a single threaded rustc. - for f in rest.iter_mut().rev() { + for (i, f) in rest.iter_mut().enumerate().rev() { let f = FromDyn::from(f); - s.spawn(|_| { - guard.run(|| (f.into_inner())()); + s.spawn(move |_| { + branch_context((i + 1).try_into().unwrap(), func_count, || { + guard.run(|| (f.into_inner())()) + }); }); } // Run the first function without spawning to // ensure it executes immediately on this thread. 
- guard.run(|| first[0]()); + branch_context(0, func_count, || guard.run(|| first[0]())); }); } else { for f in funcs { @@ -62,7 +66,7 @@ where let oper_a = FromDyn::from(oper_a); let oper_b = FromDyn::from(oper_b); let (a, b) = parallel_guard(|guard| { - rustc_thread_pool::join( + raw_branched_join( move || guard.run(move || FromDyn::from(oper_a.into_inner()())), move || guard.run(move || FromDyn::from(oper_b.into_inner()())), ) @@ -78,20 +82,50 @@ fn par_slice( guard: &ParallelGuard, for_each: impl Fn(&mut I) + DynSync + DynSend, ) { + match items { + [] => return, + [item] => { + guard.run(|| for_each(item)); + return; + } + _ => (), + } + let for_each = FromDyn::from(for_each); let mut items = for_each.derive(items); rustc_thread_pool::scope(|s| { + let for_each = &for_each; let proof = items.derive(()); - let group_size = std::cmp::max(items.len() / 128, 1); - for group in items.chunks_mut(group_size) { + + const MAX_GROUP_COUNT: usize = 128; + let group_size = items.len().div_ceil(MAX_GROUP_COUNT); + let mut groups = items.chunks_mut(group_size).enumerate(); + let group_count = groups.len().try_into().unwrap(); + + let Some((_, first_group)) = groups.next() else { return }; + + // Reverse the order of the later functions since Rayon executes them in reverse + // order when using a single thread. This ensures the execution order matches + // that of a single threaded rustc. + for (i, group) in groups.rev() { let group = proof.derive(group); - s.spawn(|_| { - let mut group = group; - for i in group.iter_mut() { - guard.run(|| for_each(i)); - } + s.spawn(move |_| { + branch_context(i.try_into().unwrap(), group_count, || { + let mut group = group; + for i in group.iter_mut() { + guard.run(|| for_each(i)); + } + }) }); } + + // Run the first function without spawning to + // ensure it executes immediately on this thread. 
+ branch_context(0, group_count, || { + for i in first_group.iter_mut() { + guard.run(|| for_each(i)); + } + }); }); } @@ -162,3 +196,30 @@ pub fn par_map, R: DynSend, C: FromIterato } }) } + +fn raw_branched_join(oper_a: A, oper_b: B) -> (RA, RB) +where + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, +{ + rustc_thread_pool::join(|| branch_context(0, 2, oper_a), || branch_context(1, 2, oper_b)) +} + +fn branch_context(branch_num: u64, branch_space: u64, f: F) -> R +where + F: FnOnce() -> R, +{ + tls::with_context_opt(|icx| { + if let Some(icx) = icx + && let Some(QueryInclusion { id, branch }) = icx.query + { + let icx = tls::ImplicitCtxt { + query: Some(QueryInclusion { id, branch: branch.branch(branch_num, branch_space) }), + ..*icx + }; + tls::enter_context(&icx, f) + } else { + f() + } + }) +} diff --git a/compiler/rustc_query_impl/src/job.rs b/compiler/rustc_query_impl/src/job.rs index 0a1e984c560b5..370ce0be2afbe 100644 --- a/compiler/rustc_query_impl/src/job.rs +++ b/compiler/rustc_query_impl/src/job.rs @@ -7,7 +7,8 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_errors::{Diag, DiagCtxtHandle}; use rustc_hir::def::DefKind; use rustc_middle::query::{ - CycleError, QueryInclusion, QueryInfo, QueryJob, QueryJobId, QueryLatch, QueryStackDeferred, QueryStackFrame, QueryWaiter + CycleError, QueryInclusion, QueryInfo, QueryJob, QueryJobId, QueryLatch, QueryStackDeferred, + QueryStackFrame, QueryWaiter, }; use rustc_middle::ty::TyCtxt; use rustc_session::Session; diff --git a/compiler/rustc_query_impl/src/plumbing.rs b/compiler/rustc_query_impl/src/plumbing.rs index 4d9ac6fc0069a..ff16e6308c1cf 100644 --- a/compiler/rustc_query_impl/src/plumbing.rs +++ b/compiler/rustc_query_impl/src/plumbing.rs @@ -19,7 +19,8 @@ use rustc_middle::query::on_disk_cache::{ }; use rustc_middle::query::plumbing::QueryVTable; use rustc_middle::query::{ - Key, QueryCache, QueryInclusion, QueryJobId, QueryStackDeferred, QueryStackFrame, QueryStackFrameExtra + 
 Key, QueryCache, QueryInclusion, QueryJobId, QueryStackDeferred, QueryStackFrame, + QueryStackFrameExtra, }; use rustc_middle::ty::codec::TyEncoder; use rustc_middle::ty::print::with_reduced_queries; From f46682827d0de0227b2d8df16fba2335c037cb31 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Fri, 6 Feb 2026 19:26:15 +0300 Subject: [PATCH 05/11] Replace break_query_cycles implementation with new --- compiler/rustc_middle/src/query/job.rs | 4 +- compiler/rustc_query_impl/src/execution.rs | 5 +- compiler/rustc_query_impl/src/job.rs | 394 +++++++-------------- 3 files changed, 136 insertions(+), 267 deletions(-) diff --git a/compiler/rustc_middle/src/query/job.rs b/compiler/rustc_middle/src/query/job.rs index e98ca70db5d13..8a499935b525c 100644 --- a/compiler/rustc_middle/src/query/job.rs +++ b/compiler/rustc_middle/src/query/job.rs @@ -80,7 +80,6 @@ pub struct QueryInclusion { pub struct QueryWaiter<'tcx> { pub query: Option, pub condvar: Condvar, - pub span: Span, pub cycle: Mutex>>>, } @@ -107,10 +106,9 @@ impl<'tcx> QueryLatch<'tcx> { &self, tcx: TyCtxt<'tcx>, query: Option, - span: Span, ) -> Result<(), CycleError>> { let waiter = - Arc::new(QueryWaiter { query, span, cycle: Mutex::new(None), condvar: Condvar::new() }); + Arc::new(QueryWaiter { query, cycle: Mutex::new(None), condvar: Condvar::new() }); self.wait_on_inner(tcx, &waiter); // FIXME: Get rid of this lock.
We have ownership of the QueryWaiter // although another thread may still have a Arc reference so we cannot diff --git a/compiler/rustc_query_impl/src/execution.rs b/compiler/rustc_query_impl/src/execution.rs index 22992d094afa1..68eac4d46e32b 100644 --- a/compiler/rustc_query_impl/src/execution.rs +++ b/compiler/rustc_query_impl/src/execution.rs @@ -229,7 +229,6 @@ fn cycle_error<'tcx, C: QueryCache>( fn wait_for_query<'tcx, C: QueryCache>( query: &'tcx QueryVTable<'tcx, C>, tcx: TyCtxt<'tcx>, - span: Span, key: C::Key, latch: QueryLatch<'tcx>, current: Option, @@ -241,7 +240,7 @@ fn wait_for_query<'tcx, C: QueryCache>( // With parallel queries we might just have to wait on some other // thread. - let result = latch.wait_on(tcx, current, span); + let result = latch.wait_on(tcx, current); match result { Ok(()) => { @@ -321,7 +320,7 @@ fn try_execute_query<'tcx, C: QueryCache, const INCR: bool>( // Only call `wait_for_query` if we're using a Rayon thread pool // as it will attempt to mark the worker thread as blocked. 
- return wait_for_query(query, tcx, span, key, latch, current_inclusion); + return wait_for_query(query, tcx, key, latch, current_inclusion); } let id = job.id; diff --git a/compiler/rustc_query_impl/src/job.rs b/compiler/rustc_query_impl/src/job.rs index 370ce0be2afbe..4178eb4174f55 100644 --- a/compiler/rustc_query_impl/src/job.rs +++ b/compiler/rustc_query_impl/src/job.rs @@ -1,18 +1,17 @@ +use std::collections::BTreeMap; use std::io::Write; -use std::iter; -use std::ops::ControlFlow; -use std::sync::Arc; -use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::indexmap::{self, IndexMap}; +use rustc_data_structures::tree_node_index::TreeNodeIndex; use rustc_errors::{Diag, DiagCtxtHandle}; use rustc_hir::def::DefKind; use rustc_middle::query::{ - CycleError, QueryInclusion, QueryInfo, QueryJob, QueryJobId, QueryLatch, QueryStackDeferred, - QueryStackFrame, QueryWaiter, + CycleError, QueryInfo, QueryJob, QueryJobId, QueryStackDeferred, QueryStackFrame, }; use rustc_middle::ty::TyCtxt; use rustc_session::Session; -use rustc_span::{DUMMY_SP, Span}; +use rustc_span::Span; use crate::plumbing::collect_active_jobs_from_all_queries; @@ -34,18 +33,6 @@ impl<'tcx> QueryJobMap<'tcx> { fn frame_of(&self, id: QueryJobId) -> &QueryStackFrame> { &self.map[&id].frame } - - fn span_of(&self, id: QueryJobId) -> Span { - self.map[&id].job.span - } - - fn parent_of(&self, id: QueryJobId) -> Option { - self.map[&id].job.parent - } - - fn latch_of(&self, id: QueryJobId) -> Option<&QueryLatch<'tcx>> { - self.map[&id].job.latch.as_ref() - } } #[derive(Clone, Debug)] @@ -112,262 +99,147 @@ pub(crate) fn find_dep_kind_root<'tcx>( last_layout } -/// A resumable waiter of a query. The usize is the index into waiters in the query's latch -type Waiter = (QueryJobId, usize); - -/// Visits all the non-resumable and resumable waiters of a query. -/// Only waiters in a query are visited. 
-/// `visit` is called for every waiter and is passed a query waiting on `query` -/// and a span indicating the reason the query waited on `query`. -/// If `visit` returns `Break`, this function also returns `Break`, -/// and if all `visit` calls returns `Continue` it also returns `Continue`. -/// For visits of non-resumable waiters it returns the return value of `visit`. -/// For visits of resumable waiters it returns information required to resume that waiter. -fn visit_waiters<'tcx>( - job_map: &QueryJobMap<'tcx>, - query: QueryJobId, - mut visit: impl FnMut(Span, QueryJobId) -> ControlFlow>, -) -> ControlFlow> { - // Visit the parent query which is a non-resumable waiter since it's on the same stack - if let Some(parent) = job_map.parent_of(query) { - visit(job_map.span_of(query), parent.id)?; - } - - // Visit the explicit waiters which use condvars and are resumable - if let Some(latch) = job_map.latch_of(query) { - for (i, waiter) in latch.info.lock().waiters.iter().enumerate() { - if let Some(waiter_query) = waiter.query { - // Return a value which indicates that this waiter can be resumed - visit(waiter.span, waiter_query.id).map_break(|_| Some((query, i)))?; - } +/// Breaks left-most cycle on a left-most query in order of a single-threaded execution. +/// +/// Order of queries is tracked using [`TreeNodeIndex`] in [`rustc_middle::sync`]. +/// This function uses ordered depth-first search from a single root query down to the first +/// duplicate query. +/// It doesn't distinguish between a query wait and a query execution, so both are just query calls. +/// As such some queries may have two or more parent query calls too. +/// +/// But while it breaks on the same query as with a single thread, +/// we are not guaranteed to break on the same query **call**. +/// This is good enough, as the difference is irrelevant to query cycle recovery code. +/// Every other difference AFAIK is tolerable. 
+/// Potential different query result values are fine as either ill-defined due to cycles or +/// as they preserve the same query result value between different query calls. +/// +/// To illustrate how it works, say we have a query cycle: +/// +/// ```text +/// a() -> b() -> a() +/// ``` +/// +/// and a program `join(|| a(), || b())`. +/// On a single-thread it triggers cycle recovery on an `a()` call within `b()` query. +/// However consider a multi-threaded execution: +/// +/// ```text +/// thread 1: waits on a() +/// thread 2: b() -> a() -> waits on b() +/// ``` +/// +/// Similar to single-threaded execution, we have to resume wait on `a()`. +/// However this time it could only be done to a *different query call*, the one inside of `join`. +/// Then we resume until thread 1 blocks in join on a `b()` task. +/// +/// ```text +/// thread 1: indirectly waits on b() +/// thread 2: b() -> a() -> waits on b() +/// ``` +/// +/// Now the left-most query to break is `b()` so we resume thread 2. +/// This difference in behavior is strictly more tolerable than the nondeterministic cycle breaking. +#[allow(rustc::potential_query_instability)] +pub fn break_query_cycles<'tcx>( + query_map: QueryJobMap<'tcx>, + registry: &rustc_thread_pool::Registry, +) { + let mut root_query = None; + for (&query, info) in &query_map.map { + if info.job.parent.is_none() { + assert!(root_query.is_none(), "found multiple threads without start"); + root_query = Some(query); } } + let root_query = root_query.expect("no root query was found"); - ControlFlow::Continue(()) -} - -/// Look for query cycles by doing a depth first search starting at `query`. -/// `span` is the reason for the `query` to execute. This is initially DUMMY_SP. -/// If a cycle is detected, this initial value is replaced with the span causing -/// the cycle.
-fn cycle_check<'tcx>( - job_map: &QueryJobMap<'tcx>, - query: QueryJobId, - span: Span, - stack: &mut Vec<(Span, QueryJobId)>, - visited: &mut FxHashSet, -) -> ControlFlow> { - if !visited.insert(query) { - return if let Some(p) = stack.iter().position(|q| q.1 == query) { - // We detected a query cycle, fix up the initial span and return Some - - // Remove previous stack entries - stack.drain(0..p); - // Replace the span for the first query with the cycle cause - stack[0].0 = span; - ControlFlow::Break(None) - } else { - ControlFlow::Continue(()) + let mut subqueries = FxHashMap::<_, BTreeMap>::default(); + for query in query_map.map.values() { + let Some(inclusion) = &query.job.parent else { + continue; }; + let old = subqueries + .entry(inclusion.id) + .or_default() + .insert(inclusion.branch, (query.job.id, usize::MAX)); + assert!(old.is_none()); } - // Query marked as visited is added it to the stack - stack.push((span, query)); - - // Visit all the waiters - let r = visit_waiters(job_map, query, |span, successor| { - cycle_check(job_map, successor, span, stack, visited) - }); - - // Remove the entry in our stack if we didn't find a cycle - if r.is_continue() { - stack.pop(); - } - - r -} - -/// Finds out if there's a path to the compiler root (aka. code which isn't in a query) -/// from `query` without going through any of the queries in `visited`. -/// This is achieved with a depth first search. 
-fn connected_to_root<'tcx>( - job_map: &QueryJobMap<'tcx>, - query: QueryJobId, - visited: &mut FxHashSet, -) -> ControlFlow> { - // We already visited this or we're deliberately ignoring it - if !visited.insert(query) { - return ControlFlow::Continue(()); - } - - // This query is connected to the root (it has no query parent), return true - if job_map.parent_of(query).is_none() { - return ControlFlow::Break(None); - } - - visit_waiters(job_map, query, |_, successor| connected_to_root(job_map, successor, visited)) -} - -/// Looks for query cycles starting from the last query in `jobs`. -/// If a cycle is found, all queries in the cycle is removed from `jobs` and -/// the function return true. -/// If a cycle was not found, the starting query is removed from `jobs` and -/// the function returns false. -fn remove_cycle<'tcx>( - job_map: &QueryJobMap<'tcx>, - jobs: &mut Vec, - wakelist: &mut Vec>>, -) -> bool { - let mut visited = FxHashSet::default(); - let mut stack = Vec::new(); - // Look for a cycle starting with the last query in `jobs` - if let ControlFlow::Break(waiter) = - cycle_check(job_map, jobs.pop().unwrap(), DUMMY_SP, &mut stack, &mut visited) - { - // The stack is a vector of pairs of spans and queries; reverse it so that - // the earlier entries require later entries - let (mut spans, queries): (Vec<_>, Vec<_>) = stack.into_iter().rev().unzip(); - - // Shift the spans so that queries are matched with the span for their waitee - spans.rotate_right(1); - - // Zip them back together - let mut stack: Vec<_> = iter::zip(spans, queries).collect(); - - // Remove the queries in our cycle from the list of jobs to look at - for r in &stack { - if let Some(pos) = jobs.iter().position(|j| j == &r.1) { - jobs.remove(pos); - } - } - - struct EntryPoint { - query_in_cycle: QueryJobId, - waiter: Option<(Span, QueryJobId)>, + for query in query_map.map.values() { + let Some(latch) = &query.job.latch else { + continue; + }; + // Latch mutexes should be at least about 
to unlock as we do not hold it anywhere too long + let lock = latch.info.lock(); + assert!(!lock.complete); + for (waiter_idx, waiter) in lock.waiters.iter().enumerate() { + let inclusion = waiter.query.expect("cannot wait on a root query"); + let old = subqueries + .entry(inclusion.id) + .or_default() + .insert(inclusion.branch, (query.job.id, waiter_idx)); + assert!(old.is_none()); } + } - // Find the queries in the cycle which are - // connected to queries outside the cycle - let entry_points = stack - .iter() - .filter_map(|&(_, query_in_cycle)| { - if job_map.parent_of(query_in_cycle).is_none() { - // This query is connected to the root (it has no query parent) - Some(EntryPoint { query_in_cycle, waiter: None }) - } else { - let mut waiter_on_cycle = None; - // Find a direct waiter who leads to the root - let _ = visit_waiters(job_map, query_in_cycle, |span, waiter| { - // Mark all the other queries in the cycle as already visited - let mut visited = FxHashSet::from_iter(stack.iter().map(|q| q.1)); - - if connected_to_root(job_map, waiter, &mut visited).is_break() { - waiter_on_cycle = Some((span, waiter)); - ControlFlow::Break(None) - } else { - ControlFlow::Continue(()) - } - }); - - waiter_on_cycle.map(|waiter_on_cycle| EntryPoint { - query_in_cycle, - waiter: Some(waiter_on_cycle), - }) - } + let mut visited = IndexMap::new(); + let mut last_usage = None; + let mut last_waiter_idx = usize::MAX; + let mut current = root_query; + while let indexmap::map::Entry::Vacant(entry) = visited.entry(current) { + entry.insert((last_usage, last_waiter_idx)); + last_usage = Some(current); + (current, last_waiter_idx) = *subqueries + .get(¤t) + .unwrap_or_else(|| { + panic!( + "deadlock detected as we're unable to find a query cycle to break\n\ + current query map:\n{:#?}", + query_map + ) }) - .collect::>(); - - // Pick an entry point, preferring ones with waiters - let entry_point = entry_points - .iter() - .find(|entry_point| entry_point.waiter.is_some()) - 
.unwrap_or(&entry_points[0]); - - // Shift the stack so that our entry point is first - let entry_point_pos = - stack.iter().position(|(_, query)| *query == entry_point.query_in_cycle); - if let Some(pos) = entry_point_pos { - stack.rotate_left(pos); - } - - let usage = entry_point.waiter.map(|(span, job)| (span, job_map.frame_of(job).clone())); - - // Create the cycle error - let error = CycleError { - usage, - cycle: stack - .iter() - .map(|&(span, job)| QueryInfo { span, frame: job_map.frame_of(job).clone() }) - .collect(), - }; - - // We unwrap `waiter` here since there must always be one - // edge which is resumable / waited using a query latch - let (waitee_query, waiter_idx) = waiter.unwrap(); - - // Extract the waiter we want to resume - let waiter = job_map.latch_of(waitee_query).unwrap().extract_waiter(waiter_idx); - - // Set the cycle error so it will be picked up when resumed - *waiter.cycle.lock() = Some(error); - - // Put the waiter on the list of things to resume - wakelist.push(waiter); - - true - } else { - false + .first_key_value() + .unwrap() + .1; } -} - -/// Detects query cycles by using depth first search over all active query jobs. -/// If a query cycle is found it will break the cycle by finding an edge which -/// uses a query latch and then resuming that waiter. -/// There may be multiple cycles involved in a deadlock, so this searches -/// all active queries for cycles before finally resuming all the waiters at once. 
-pub fn break_query_cycles<'tcx>( - job_map: QueryJobMap<'tcx>, - registry: &rustc_thread_pool::Registry, -) { - let mut wakelist = Vec::new(); - // It is OK per the comments: - // - https://github.com/rust-lang/rust/pull/131200#issuecomment-2798854932 - // - https://github.com/rust-lang/rust/pull/131200#issuecomment-2798866392 - #[allow(rustc::potential_query_instability)] - let mut jobs: Vec = job_map.map.keys().copied().collect(); - - let mut found_cycle = false; - - while jobs.len() > 0 { - if remove_cycle(&job_map, &mut jobs, &mut wakelist) { - found_cycle = true; + let usage = visited[¤t].0; + let mut iter = visited.keys().rev(); + let mut cycle = Vec::new(); + loop { + let query_id = *iter.next().unwrap(); + let query = &query_map.map[&query_id]; + cycle.push(QueryInfo { span: query.job.span, frame: query.frame.clone() }); + if query_id == current { + break; } } - // Check that a cycle was found. It is possible for a deadlock to occur without - // a query cycle if a query which can be waited on uses Rayon to do multithreading - // internally. Such a query (X) may be executing on 2 threads (A and B) and A may - // wait using Rayon on B. Rayon may then switch to executing another query (Y) - // which in turn will wait on X causing a deadlock. We have a false dependency from - // X to Y due to Rayon waiting and a true dependency from Y to X. The algorithm here - // only considers the true dependency and won't detect a cycle. - if !found_cycle { - panic!( - "deadlock detected as we're unable to find a query cycle to break\n\ - current query map:\n{job_map:#?}", - ); - } - - // Mark all the thread we're about to wake up as unblocked. This needs to be done before - // we wake the threads up as otherwise Rayon could detect a deadlock if a thread we - // resumed fell asleep and this thread had yet to mark the remaining threads as unblocked. 
- for _ in 0..wakelist.len() { - rustc_thread_pool::mark_unblocked(registry); - } + cycle.reverse(); + let cycle_error = CycleError { + usage: usage.map(|id| { + let query = &query_map.map[&id]; + (query.job.span, query.frame.clone()) + }), + cycle, + }; - for waiter in wakelist.into_iter() { - waiter.condvar.notify_one(); - } + let (waited_on, waiter_idx) = if last_waiter_idx != usize::MAX { + (current, last_waiter_idx) + } else { + let (&waited_on, &(_, waiter_idx)) = + visited.iter().rev().find(|(_, (_, waiter_idx))| *waiter_idx != usize::MAX).unwrap(); + (waited_on, waiter_idx) + }; + let waited_on = &query_map.map[&waited_on]; + let latch = waited_on.job.latch.as_ref().unwrap(); + let mut latch_info_lock = latch.info.try_lock().unwrap(); + let waiter = latch_info_lock.waiters.remove(waiter_idx); + let mut cycle_lock = waiter.cycle.try_lock().unwrap(); + assert!(cycle_lock.is_none()); + *cycle_lock = Some(cycle_error); + rustc_thread_pool::mark_unblocked(registry); + waiter.condvar.notify_one(); } pub fn print_query_stack<'tcx>( From 5f691612e7475cb26569d5bc24015398a9977f79 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Wed, 25 Feb 2026 17:36:45 +0300 Subject: [PATCH 06/11] Add TreeNodeIndex unit tests --- .../src/tree_node_index.rs | 24 +++++----- .../src/tree_node_index/tests.rs | 45 +++++++++++++++++++ 2 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 compiler/rustc_data_structures/src/tree_node_index/tests.rs diff --git a/compiler/rustc_data_structures/src/tree_node_index.rs b/compiler/rustc_data_structures/src/tree_node_index.rs index dfe03f9776cab..b1bf721bf2a8c 100644 --- a/compiler/rustc_data_structures/src/tree_node_index.rs +++ b/compiler/rustc_data_structures/src/tree_node_index.rs @@ -29,7 +29,7 @@ use std::fmt::Display; /// This is done in query cycle handling code to determine **intended** first task for a single- /// threaded compiler front-end to execute even while multi-threaded. 
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct TreeNodeIndex(pub u64); +pub struct TreeNodeIndex(u64); impl TreeNodeIndex { pub const fn root() -> Self { @@ -43,7 +43,7 @@ impl TreeNodeIndex { Ok(TreeNodeIndex( self.0 & !(1 << trailing_zeros) | (1 << allocated_shift) - | (branch_idx << (allocated_shift + 1)), + | branch_idx.unbounded_shl(allocated_shift + 1), )) } @@ -53,17 +53,16 @@ impl TreeNodeIndex { branch_idx < branch_num, "branch_idx = {branch_idx} should be less than branch_num = {branch_num}" ); - // floor(log2(n - 1)) + 1 == ceil(log2(n)) - self.try_bits_branch(branch_idx, (branch_num - 1).checked_ilog2().map_or(0, |b| b + 1)) - .unwrap() + // `branch_num != 0` per debug assertion above + let bits = ceil_ilog2(branch_num); + self.try_bits_branch(branch_idx, bits).unwrap() } +} - pub fn try_concat(self, then: Self) -> Result { - let trailing_zeros = then.0.trailing_zeros(); - let branch_num = then.0.wrapping_shr(trailing_zeros + 1); - let bits = u64::BITS - trailing_zeros; - self.try_bits_branch(branch_num, bits) - } +#[inline] +fn ceil_ilog2(branch_num: u64) -> u32 { + // floor(log2(n - 1)) + 1 == ceil(log2(n)) + (branch_num - 1).checked_ilog2().map_or(0, |b| b + 1) } /// Error for exhausting free bits @@ -84,3 +83,6 @@ impl Default for TreeNodeIndex { TreeNodeIndex::root() } } + +#[cfg(test)] +mod tests; diff --git a/compiler/rustc_data_structures/src/tree_node_index/tests.rs b/compiler/rustc_data_structures/src/tree_node_index/tests.rs new file mode 100644 index 0000000000000..f2675a7c21cbb --- /dev/null +++ b/compiler/rustc_data_structures/src/tree_node_index/tests.rs @@ -0,0 +1,45 @@ +use crate::tree_node_index::TreeNodeIndex; + +#[test] +fn up_to_16() { + for n in 1..128 { + for i in 0..n { + TreeNodeIndex::root().branch(i, n).branch(n - i - 1, n); + } + } +} + +#[test] +fn ceil_log2() { + const EVALUATION_TABLE: [(u64, u32); 9] = + [(1, 0), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 3), (8, 3), (u64::MAX, 64)]; + for &(x, 
y) in &EVALUATION_TABLE { + let r = super::ceil_ilog2(x); + assert!(r == y, "ceil_ilog2({x}) == {r} != {y}"); + } +} + +#[test] +fn some_cases() { + let mut tni = TreeNodeIndex::root(); + tni = tni.branch(0xDEAD, 0xFADE); + assert_eq!(tni.0, 0xDEAD8000_00000000); + tni = tni.branch(0xBEEF, 0xCCCC); + assert_eq!(tni.0, 0xDEADBEEF_80000000); + tni = tni.branch(1, 2); + assert_eq!(tni.0, 0xDEADBEEF_C0000000); + tni = tni.branch(0, 2); + assert_eq!(tni.0, 0xDEADBEEF_A0000000); + tni = tni.branch(3, 4); + assert_eq!(tni.0, 0xDEADBEEF_B8000000); + tni = tni.branch(0xAAAAAA, 0xBBBBBB); + assert_eq!(tni.0, 0xDEADBEEF_BAAAAAA8); +} + +#[test] +fn edge_cases() { + const ROOT: TreeNodeIndex = TreeNodeIndex::root(); + assert_eq!(ROOT.branch(0, 1), TreeNodeIndex::root()); + assert_eq!(ROOT.branch(u64::MAX >> 1, 1 << 63).0, u64::MAX); + assert_eq!(ROOT.branch(0, 1 << 63).0, 1); +} From 6c9954ea9e3c6c9cc95f194a89fd771b25c8df4b Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Wed, 25 Feb 2026 18:06:50 +0300 Subject: [PATCH 07/11] Optimize TreeNodeIndex with wrapping operations (checked in godbolt.org) --- .../rustc_data_structures/src/tree_node_index.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/compiler/rustc_data_structures/src/tree_node_index.rs b/compiler/rustc_data_structures/src/tree_node_index.rs index b1bf721bf2a8c..6199f7ec20453 100644 --- a/compiler/rustc_data_structures/src/tree_node_index.rs +++ b/compiler/rustc_data_structures/src/tree_node_index.rs @@ -40,10 +40,13 @@ impl TreeNodeIndex { fn try_bits_branch(self, branch_idx: u64, bits: u32) -> Result { let trailing_zeros = self.0.trailing_zeros(); let allocated_shift = trailing_zeros.checked_sub(bits).ok_or(BranchingError(()))?; + // Using wrapping operations for optimization, as edge cases are unreachable: + // - `trailing_zeros < 64` as we are guaranteed at least one bit is set + // - `allocated_shift == trailing_zeros - bits <= trailing_zeros < 64` Ok(TreeNodeIndex( - 
self.0 & !(1 << trailing_zeros) - | (1 << allocated_shift) - | branch_idx.unbounded_shl(allocated_shift + 1), + self.0 & !u64::wrapping_shl(1, trailing_zeros) + | u64::wrapping_shl(1, allocated_shift) + | branch_idx.unbounded_shl(allocated_shift.wrapping_add(1)), )) } @@ -61,8 +64,9 @@ impl TreeNodeIndex { #[inline] fn ceil_ilog2(branch_num: u64) -> u32 { - // floor(log2(n - 1)) + 1 == ceil(log2(n)) - (branch_num - 1).checked_ilog2().map_or(0, |b| b + 1) + // Using `wrapping_sub` for optimization, consider `log(0)` to be undefined + // `floor(log2(n - 1)) + 1 == ceil(log2(n))` + branch_num.wrapping_sub(1).checked_ilog2().map_or(0, |b| b.wrapping_add(1)) } /// Error for exhausting free bits From ee40fee423e5de1be74c5712f05a132b0a6b3481 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Wed, 25 Feb 2026 18:15:17 +0300 Subject: [PATCH 08/11] Get rid of BranchingError as it is no longer used --- .../src/tree_node_index.rs | 32 ++++--------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/compiler/rustc_data_structures/src/tree_node_index.rs b/compiler/rustc_data_structures/src/tree_node_index.rs index 6199f7ec20453..ff332dc70ca1e 100644 --- a/compiler/rustc_data_structures/src/tree_node_index.rs +++ b/compiler/rustc_data_structures/src/tree_node_index.rs @@ -1,6 +1,3 @@ -use std::error::Error; -use std::fmt::Display; - /// Ordered index for dynamic trees /// /// ## Encoding @@ -37,13 +34,13 @@ impl TreeNodeIndex { } /// Append tree branch no. `branch_idx` reserving `bits` bits. 
- fn try_bits_branch(self, branch_idx: u64, bits: u32) -> Result { + fn try_bits_branch(self, branch_idx: u64, bits: u32) -> Option { let trailing_zeros = self.0.trailing_zeros(); - let allocated_shift = trailing_zeros.checked_sub(bits).ok_or(BranchingError(()))?; + let allocated_shift = trailing_zeros.checked_sub(bits)?; // Using wrapping operations for optimization, as edge cases are unreachable: // - `trailing_zeros < 64` as we are guaranteed at least one bit is set // - `allocated_shift == trailing_zeros - bits <= trailing_zeros < 64` - Ok(TreeNodeIndex( + Some(TreeNodeIndex( self.0 & !u64::wrapping_shl(1, trailing_zeros) | u64::wrapping_shl(1, allocated_shift) | branch_idx.unbounded_shl(allocated_shift.wrapping_add(1)), @@ -58,7 +55,9 @@ impl TreeNodeIndex { ); // `branch_num != 0` per debug assertion above let bits = ceil_ilog2(branch_num); - self.try_bits_branch(branch_idx, bits).unwrap() + self.try_bits_branch(branch_idx, bits).expect( + "TreeNodeIndex's free bits have been exhausted, make sure recursion is used carefully", + ) } } @@ -69,24 +68,5 @@ fn ceil_ilog2(branch_num: u64) -> u32 { branch_num.wrapping_sub(1).checked_ilog2().map_or(0, |b| b.wrapping_add(1)) } -/// Error for exhausting free bits -#[derive(Debug)] -pub struct BranchingError(()); - -impl Error for BranchingError {} - -impl Display for BranchingError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - "TreeNodeIndex's free bits have been exhausted, make sure recursion is used carefully" - .fmt(f) - } -} - -impl Default for TreeNodeIndex { - fn default() -> Self { - TreeNodeIndex::root() - } -} - #[cfg(test)] mod tests; From cf6f16afd553c89d692493ad3c41e1415f9a3d91 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Wed, 25 Feb 2026 19:52:19 +0300 Subject: [PATCH 09/11] Expand comments for TreeNodeIndex --- .../src/tree_node_index.rs | 80 +++++++++++++------ 1 file changed, 54 insertions(+), 26 deletions(-) diff --git 
a/compiler/rustc_data_structures/src/tree_node_index.rs b/compiler/rustc_data_structures/src/tree_node_index.rs index ff332dc70ca1e..d29d5a2d4ff81 100644 --- a/compiler/rustc_data_structures/src/tree_node_index.rs +++ b/compiler/rustc_data_structures/src/tree_node_index.rs @@ -11,52 +11,80 @@ /// 0bXXXXXXX100000000...0 /// ``` /// -/// Node reach after traversal of `LRLRRLLR` branches should be represented as `0b0101100110000...0`. -/// Root is obviously encoded as `0b10000...0`. +/// The reached node after traversal of `LRLRRLLR` branches (`L` for left, `R` for right) should be +/// represented as `0b0101100110000...0`. +/// Root is encoded as `0b10000...0` since from an empty bitstring we don't need to traverse any +/// branch to reach a binary tree's root. +/// +/// Here are some examples: +/// +/// ```text +/// (root) -> 0b10000000...0 +/// L (left) -> 0b01000000...0 +/// R (right) -> 0b11000000...0 +/// LL -> 0b00100000...0 +/// RLR -> 0b10110000...0 +/// LRL -> 0b01010000...0 +/// LRRLR -> 0b01101100...0 +/// ``` +/// +/// ## Multi-way tree +/// +/// But we don't necessarily need to encode a binary tree directly. +/// We can imagine some node to have `N` number of branches instead of two: right and left. +/// We encode `0 <= i < N` numbered branches by interpreting `i`'s binary representation as +/// bitstring for a binary tree traversal. +/// +/// For example `N = 3`. Notice how the right-most leaf node is unused: +/// +/// ```text +/// root +/// root / \ +/// / | \ => . . +/// 0 1 2 / \ / \ +/// 0 1 2 - +/// ``` /// /// ## Order /// -/// Encoding allows to sort nodes in left < parent < right linear order. -/// If you only consider leaves of a tree then those are sorted in order left < right. +/// Encoding allows sorting nodes in `left < parent < right` linear order. +/// If you only consider leaves of a tree then those are sorted in order `left < right`.
/// /// ## Used in /// -/// Primary purpose of `TreeNodeIndex` is to track order of parallel tasks of functions like `join` -/// or `scope` (see `rustc_middle::sync`). -/// This is done in query cycle handling code to determine **intended** first task for a single- -/// threaded compiler front-end to execute even while multi-threaded. +/// Primary purpose of `TreeNodeIndex` is to track order of parallel tasks of functions like +/// `par_join`, `par_slice`, and others (see `rustc_middle::sync`). +/// This is done in query cycle handling code to determine **intended** first task for a +/// single-threaded compiler front-end to execute even while multi-threaded. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct TreeNodeIndex(u64); impl TreeNodeIndex { + /// Root node of a tree pub const fn root() -> Self { Self(0x80000000_00000000) } - /// Append tree branch no. `branch_idx` reserving `bits` bits. - fn try_bits_branch(self, branch_idx: u64, bits: u32) -> Option { + /// Append branch `i` out of `n` branches total to `TreeNodeIndex`'s traversal representation. + /// + /// This method reserves `ceil(log2(n))` bits within `TreeNodeIndex`'s integer encoded + /// bitstring. 
+ pub fn branch(self, i: u64, n: u64) -> TreeNodeIndex { + debug_assert!(i < n, "i = {i} should be less than n = {n}"); + // `branch_num != 0` per debug assertion above + let bits = ceil_ilog2(n); + let trailing_zeros = self.0.trailing_zeros(); - let allocated_shift = trailing_zeros.checked_sub(bits)?; + let allocated_shift = trailing_zeros.checked_sub(bits).expect( + "TreeNodeIndex's free bits have been exhausted, make sure recursion is used carefully", + ); // Using wrapping operations for optimization, as edge cases are unreachable: // - `trailing_zeros < 64` as we are guaranteed at least one bit is set // - `allocated_shift == trailing_zeros - bits <= trailing_zeros < 64` - Some(TreeNodeIndex( + TreeNodeIndex( self.0 & !u64::wrapping_shl(1, trailing_zeros) | u64::wrapping_shl(1, allocated_shift) - | branch_idx.unbounded_shl(allocated_shift.wrapping_add(1)), - )) - } - - /// Append tree branch no. `branch_idx` reserving `ceil(log2(branch_num))` bits. - pub fn branch(self, branch_idx: u64, branch_num: u64) -> TreeNodeIndex { - debug_assert!( - branch_idx < branch_num, - "branch_idx = {branch_idx} should be less than branch_num = {branch_num}" - ); - // `branch_num != 0` per debug assertion above - let bits = ceil_ilog2(branch_num); - self.try_bits_branch(branch_idx, bits).expect( - "TreeNodeIndex's free bits have been exhausted, make sure recursion is used carefully", + | i.unbounded_shl(allocated_shift.wrapping_add(1)), ) } } From 6e39f242c2360546b4bdb8286160bced6a2b55b4 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Thu, 26 Feb 2026 16:22:06 +0300 Subject: [PATCH 10/11] Clarify how TreeNodeIndex could exhaust its u64 encoding and panic --- compiler/rustc_data_structures/src/tree_node_index.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_data_structures/src/tree_node_index.rs b/compiler/rustc_data_structures/src/tree_node_index.rs index d29d5a2d4ff81..4907eadbe7498 100644 --- 
a/compiler/rustc_data_structures/src/tree_node_index.rs +++ b/compiler/rustc_data_structures/src/tree_node_index.rs @@ -71,13 +71,18 @@ impl TreeNodeIndex { /// bitstring. pub fn branch(self, i: u64, n: u64) -> TreeNodeIndex { debug_assert!(i < n, "i = {i} should be less than n = {n}"); - // `branch_num != 0` per debug assertion above + // `n != 0` per debug assertion above let bits = ceil_ilog2(n); - let trailing_zeros = self.0.trailing_zeros(); + + // For this panic to happen there has to be a recursive function that isn't a query and + // uses par_join or par_slice recursively. + // Each query starts with a fresh binary tree, so we can expect this to never happen. + // That is unless someone writes 64 nested par_join calls or something equivalent. let allocated_shift = trailing_zeros.checked_sub(bits).expect( "TreeNodeIndex's free bits have been exhausted, make sure recursion is used carefully", ); + // Using wrapping operations for optimization, as edge cases are unreachable: // - `trailing_zeros < 64` as we are guaranteed at least one bit is set // - `allocated_shift == trailing_zeros - bits <= trailing_zeros < 64` From 15674b4d5a3d84e4b32699601526b25f49dbb899 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Thu, 26 Feb 2026 16:56:07 +0300 Subject: [PATCH 11/11] Do small refactor and add comments in rustc_middle::sync --- compiler/rustc_middle/src/sync.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/compiler/rustc_middle/src/sync.rs b/compiler/rustc_middle/src/sync.rs index 63929a8d87b41..8ac3000212a18 100644 --- a/compiler/rustc_middle/src/sync.rs +++ b/compiler/rustc_middle/src/sync.rs @@ -66,9 +66,11 @@ where let oper_a = FromDyn::from(oper_a); let oper_b = FromDyn::from(oper_b); let (a, b) = parallel_guard(|guard| { - raw_branched_join( - move || guard.run(move || FromDyn::from(oper_a.into_inner()())), - move || guard.run(move || FromDyn::from(oper_b.into_inner()())), + let task_a = move || guard.run(move 
|| FromDyn::from(oper_a.into_inner()())); + let task_b = move || guard.run(move || FromDyn::from(oper_b.into_inner()())); + rustc_thread_pool::join( + || branch_context(0, 2, task_a), + || branch_context(1, 2, task_b), ) }); (a.unwrap().into_inner(), b.unwrap().into_inner()) @@ -197,15 +199,11 @@ pub fn par_map, R: DynSend, C: FromIterato }) } -fn raw_branched_join(oper_a: A, oper_b: B) -> (RA, RB) -where - A: FnOnce() -> RA + Send, - B: FnOnce() -> RB + Send, -{ - rustc_thread_pool::join(|| branch_context(0, 2, oper_a), || branch_context(1, 2, oper_b)) -} - -fn branch_context(branch_num: u64, branch_space: u64, f: F) -> R +/// Append `i`-th branch out of `n` branches to `icx.query.branch` to track inside of +/// which parallel task every query call is performed. +/// +/// See [`rustc_data_structures::tree_node_index::TreeNodeIndex`]. +fn branch_context(i: u64, n: u64, f: F) -> R where F: FnOnce() -> R, { @@ -214,7 +212,7 @@ where && let Some(QueryInclusion { id, branch }) = icx.query { let icx = tls::ImplicitCtxt { - query: Some(QueryInclusion { id, branch: branch.branch(branch_num, branch_space) }), + query: Some(QueryInclusion { id, branch: branch.branch(i, n) }), ..*icx }; tls::enter_context(&icx, f)