Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_cranelift/src/driver/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ use rustc_codegen_ssa::base::determine_cgu_reuse;
use rustc_codegen_ssa::{CodegenResults, CompiledModule, CrateInfo, ModuleKind};
use rustc_data_structures::profiling::SelfProfilerRef;
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
use rustc_data_structures::sync::{IntoDynSyncSend, par_map};
use rustc_hir::attrs::Linkage as RLinkage;
use rustc_middle::dep_graph::{WorkProduct, WorkProductId};
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
use rustc_middle::mir::mono::{CodegenUnit, MonoItem, MonoItemData, Visibility};
use rustc_middle::sync::{IntoDynSyncSend, par_map};
use rustc_session::Session;
use rustc_session::config::{OutputFilenames, OutputType};

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use rustc_ast::expand::allocator::{
};
use rustc_data_structures::fx::{FxHashMap, FxIndexSet};
use rustc_data_structures::profiling::{get_resident_set_size, print_time_passes_entry};
use rustc_data_structures::sync::{IntoDynSyncSend, par_map};
use rustc_data_structures::unord::UnordMap;
use rustc_hir::attrs::{DebuggerVisualizerType, OptimizeAttr};
use rustc_hir::def_id::{DefId, LOCAL_CRATE};
Expand All @@ -26,6 +25,7 @@ use rustc_middle::mir::BinOp;
use rustc_middle::mir::interpret::ErrorHandled;
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
use rustc_middle::query::Providers;
use rustc_middle::sync::{IntoDynSyncSend, par_map};
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_data_structures/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub mod temp_dir;
pub mod thinvec;
pub mod thousands;
pub mod transitive_relation;
pub mod tree_node_index;
pub mod unhash;
pub mod union_find;
pub mod unord;
Expand Down
5 changes: 1 addition & 4 deletions compiler/rustc_data_structures/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ pub use self::freeze::{FreezeLock, FreezeReadGuard, FreezeWriteGuard};
#[doc(no_inline)]
pub use self::lock::{Lock, LockGuard, Mode};
pub use self::mode::{is_dyn_thread_safe, set_dyn_thread_safe_mode};
pub use self::parallel::{
broadcast, par_fns, par_for_each_in, par_join, par_map, parallel_guard, spawn,
try_par_for_each_in,
};
pub use self::parallel::{ParallelGuard, broadcast, parallel_guard, spawn};
pub use self::vec::{AppendOnlyIndexVec, AppendOnlyVec};
pub use self::worker_local::{Registry, WorkerLocal};
pub use crate::marker::*;
Expand Down
159 changes: 0 additions & 159 deletions compiler/rustc_data_structures/src/sync/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,6 @@ pub fn parallel_guard<R>(f: impl FnOnce(&ParallelGuard) -> R) -> R {
ret
}

fn serial_join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful in the commit message to add a brief explanation why these functions are being moved. Something like "Because the next commit will modify them to use ImplicitCtxt which is not available in rustc_data_structures."

where
A: FnOnce() -> RA,
B: FnOnce() -> RB,
{
let (a, b) = parallel_guard(|guard| {
let a = guard.run(oper_a);
let b = guard.run(oper_b);
(a, b)
});
(a.unwrap(), b.unwrap())
}

pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
if mode::is_dyn_thread_safe() {
let func = FromDyn::from(func);
Expand All @@ -67,152 +54,6 @@ pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
}
}

/// Runs the functions in parallel.
///
/// The first function is executed immediately on the current thread.
/// Use that for the longest running function for better scheduling.
///
/// When `mode::is_dyn_thread_safe()` returns `false` the functions are instead called one
/// after another on the current thread, in slice order. With an empty `funcs` slice no
/// function is called in either mode.
///
/// Every call is routed through `guard.run`.
/// NOTE(review): `ParallelGuard::run` presumably captures a panic so that `parallel_guard`
/// can re-raise it once all functions have run — confirm against its definition.
pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) {
parallel_guard(|guard: &ParallelGuard| {
if mode::is_dyn_thread_safe() {
// Wrap the slice so that it may be moved into the thread pool scope.
let funcs = FromDyn::from(funcs);
rustc_thread_pool::scope(|s| {
// `split_at_mut_checked(1)` yields `None` only for an empty slice: nothing to run.
let Some((first, rest)) = funcs.into_inner().split_at_mut_checked(1) else {
return;
};

// Reverse the order of the later functions since Rayon executes them in reverse
// order when using a single thread. This ensures the execution order matches
// that of a single threaded rustc.
for f in rest.iter_mut().rev() {
// Each remaining function is wrapped individually before being handed to a task.
let f = FromDyn::from(f);
s.spawn(|_| {
guard.run(|| (f.into_inner())());
});
}

// Run the first function without spawning to
// ensure it executes immediately on this thread.
guard.run(|| first[0]());
});
} else {
// Serial mode: plain in-order execution on the current thread.
for f in funcs {
guard.run(|| f());
}
}
});
}

/// Runs `oper_a` and `oper_b` and returns both results as a pair.
///
/// In dyn-thread-safe mode the two closures are handed to `rustc_thread_pool::join`, each one
/// executed through the `ParallelGuard` supplied by `parallel_guard`. Otherwise the work is
/// delegated to `serial_join`.
#[inline]
pub fn par_join<A, B, RA: DynSend, RB: DynSend>(oper_a: A, oper_b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + DynSend,
    B: FnOnce() -> RB + DynSend,
{
    // Serial mode needs none of the `FromDyn` wrapping below.
    if !mode::is_dyn_thread_safe() {
        return serial_join(oper_a, oper_b);
    }

    let task_a = FromDyn::from(oper_a);
    let task_b = FromDyn::from(oper_b);
    let (outcome_a, outcome_b) = parallel_guard(|guard| {
        rustc_thread_pool::join(
            // The results are wrapped as well so that they can travel back across threads.
            move || guard.run(move || FromDyn::from(task_a.into_inner()())),
            move || guard.run(move || FromDyn::from(task_b.into_inner()())),
        )
    });
    (outcome_a.unwrap().into_inner(), outcome_b.unwrap().into_inner())
}

/// Calls `for_each` on every element of `items`, distributing the work over tasks spawned
/// inside a `rustc_thread_pool::scope`.
///
/// The slice is cut into chunks of `max(items.len() / 128, 1)` elements and one task is
/// spawned per chunk. Within a chunk the elements are visited in order, and every call to
/// `for_each` goes through `guard.run`.
///
/// NOTE(review): the scope presumably returns only after all spawned tasks have finished
/// (rayon-style scope semantics) — confirm against `rustc_thread_pool`.
fn par_slice<I: DynSend>(
items: &mut [I],
guard: &ParallelGuard,
for_each: impl Fn(&mut I) + DynSync + DynSend,
) {
let for_each = FromDyn::from(for_each);
// NOTE(review): `derive` appears to wrap another value in `FromDyn` on the strength of an
// already existing `FromDyn` value — confirm against the `FromDyn` definition.
let mut items = for_each.derive(items);
rustc_thread_pool::scope(|s| {
// A unit-valued wrapper used only to derive wrappers for the individual chunks below.
let proof = items.derive(());
// At least one element per chunk; for large slices this caps the chunk count near 128.
let group_size = std::cmp::max(items.len() / 128, 1);
for group in items.chunks_mut(group_size) {
let group = proof.derive(group);
s.spawn(|_| {
// Rebind inside the task so the wrapped chunk is moved into the closure.
let mut group = group;
for i in group.iter_mut() {
guard.run(|| for_each(i));
}
});
}
});
}

/// Calls `for_each` with a shared reference to every item produced by `t`.
///
/// In dyn-thread-safe mode the items are first collected into a `Vec` and then processed by
/// `par_slice`. Otherwise they are visited one by one on the current thread, in iteration
/// order. In both modes each call is routed through the guard provided by `parallel_guard`.
pub fn par_for_each_in<I: DynSend, T: IntoIterator<Item = I>>(
    t: T,
    for_each: impl Fn(&I) + DynSync + DynSend,
) {
    parallel_guard(|guard| {
        // Serial mode: a plain loop is all that is needed.
        if !mode::is_dyn_thread_safe() {
            for item in t {
                guard.run(|| for_each(&item));
            }
            return;
        }

        // `par_slice` works on a mutable slice, so materialize the iterator first.
        let mut collected: Vec<_> = t.into_iter().collect();
        par_slice(&mut collected, guard, |item| for_each(&*item));
    });
}

/// Runs `for_each` for every item of `t`, in parallel when dyn-thread-safe mode is enabled.
///
/// Returns `Err` if at least one `for_each` call returned `Err`, and `Ok(())` otherwise.
/// Which error is returned is non-deterministic in parallel mode, but this function is
/// expected to be used with `ErrorGuaranteed`, whose values are all equivalent.
/// An error does not stop the iteration: `for_each` is still called for the remaining items.
pub fn try_par_for_each_in<T: IntoIterator, E: DynSend>(
    t: T,
    for_each: impl Fn(&<T as IntoIterator>::Item) -> Result<(), E> + DynSync + DynSend,
) -> Result<(), E>
where
    <T as IntoIterator>::Item: DynSend,
{
    parallel_guard(|guard| {
        // Serial mode: fold the individual results, keeping the first error encountered.
        if !mode::is_dyn_thread_safe() {
            return t
                .into_iter()
                .filter_map(|item| guard.run(|| for_each(&item)))
                .fold(Ok(()), Result::and);
        }

        let mut collected: Vec<_> = t.into_iter().collect();

        // Shared slot for an error; a later failure simply overwrites an earlier one.
        let failure = Mutex::new(None);

        par_slice(&mut collected, guard, |item| {
            if let Err(err) = for_each(&*item) {
                *failure.lock() = Some(err);
            }
        });

        match failure.into_inner() {
            Some(err) => Err(err),
            None => Ok(()),
        }
    })
}

/// Applies `map` to every item of `t` and collects the produced values into `C`.
///
/// In dyn-thread-safe mode the items are processed by `par_slice`, and the outputs are
/// gathered in the order of the inputs. Otherwise the items are mapped one by one on the
/// current thread. Items for which no output is available are left out of the collection.
/// NOTE(review): an output is presumably only missing when `map` panicked for that item, in
/// which case `parallel_guard` is expected to re-raise the panic — confirm.
pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterator<R>>(
    t: T,
    map: impl Fn(I) -> R + DynSync + DynSend,
) -> C {
    parallel_guard(|guard| {
        // Serial mode: map lazily and collect directly.
        if !mode::is_dyn_thread_safe() {
            return t.into_iter().filter_map(|item| guard.run(|| map(item))).collect();
        }

        let map = FromDyn::from(map);

        // Every slot pairs a not yet consumed input with the place for its output.
        let mut slots: Vec<(Option<I>, Option<R>)> =
            t.into_iter().map(|item| (Some(item), None)).collect();

        par_slice(&mut slots, guard, |slot| {
            let input = slot.0.take().unwrap();
            slot.1 = Some(map(input));
        });

        slots.into_iter().filter_map(|(_, output)| output).collect()
    })
}

pub fn broadcast<R: DynSend>(op: impl Fn(usize) -> R + DynSync) -> Vec<R> {
if mode::is_dyn_thread_safe() {
let op = FromDyn::from(op);
Expand Down
105 changes: 105 additions & 0 deletions compiler/rustc_data_structures/src/tree_node_index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/// Ordered index of a node in a dynamically growing tree.
///
/// ## Encoding
///
/// Any node of a binary tree can be addressed by a variable-length bitstring in which each bit
/// names the branch (edge) to take while walking from the root down to that node.
/// The bitstring is stored in a `u64`, starting at the highest bit and moving towards the
/// lowest one. Directly after the last path bit a single bit is set high, which makes the
/// length of the path recoverable:
///
/// ```text
/// 0bXXXXXXX100000000...0
/// ```
///
/// For instance, the node reached by taking the branches `LRLRRLLR` (`L` for left, `R` for
/// right) is represented as `0b0101100110000...0`.
/// The root is encoded as `0b10000...0`: its path is the empty bitstring, because no branch
/// has to be taken to reach the root of a binary tree.
///
/// Here are some examples:
///
/// ```text
/// (root)    -> 0b10000000...0
/// L (left)  -> 0b01000000...0
/// R (right) -> 0b11000000...0
/// LL        -> 0b00100000...0
/// RLR       -> 0b10110000...0
/// LRL       -> 0b01010000...0
/// LRRLR     -> 0b01101100...0
/// ```
///
/// ## Multi-way tree
///
/// The tree being encoded does not necessarily have to be binary.
/// A node may have `N` branches instead of just two (left and right).
/// Branch number `i`, with `0 <= i < N`, is encoded by interpreting the binary representation
/// of `i` as the bitstring of a binary tree traversal.
///
/// For example, with `N = 3` (notice that the right-most leaf node stays unused):
///
/// ```text
///                        root
///      root              /  \
///     / | \      =>     .    .
///    0  1  2           / \  / \
///                     0  1  2  -
/// ```
///
/// ## Order
///
/// The encoding sorts nodes in the linear order `left < parent < right`.
/// If only the leaves of a tree are considered, they are sorted in the order `left < right`.
///
/// ## Used in
///
/// The primary purpose of `TreeNodeIndex` is to track the order of the parallel tasks created
/// by functions like `par_join`, `par_slice`, and others (see `rustc_middle::sync`).
/// The query cycle handling code uses it to determine the task that a single-threaded compiler
/// front-end is **intended** to execute first, even when running multi-threaded.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct TreeNodeIndex(u64);

impl TreeNodeIndex {
    /// Index of the root node: an empty path, i.e. only the highest bit is set.
    pub const fn root() -> Self {
        Self(1 << 63)
    }

    /// Extends the path of this index by branch `i` out of `n` branches in total.
    ///
    /// `ceil(log2(n))` bits of the integer-encoded bitstring are reserved for the new branch.
    ///
    /// # Panics
    ///
    /// Panics when fewer free bits are left than the branch requires. With debug assertions
    /// enabled it also panics unless `i < n`.
    pub fn branch(self, i: u64, n: u64) -> TreeNodeIndex {
        debug_assert!(i < n, "i = {i} should be less than n = {n}");
        // `n != 0` per debug assertion above
        let width = ceil_ilog2(n);
        // Position of the bit that currently terminates the path.
        let end_bit = self.0.trailing_zeros();

        // Running out of bits requires a recursive function that is not a query and that
        // calls par_join or par_slice recursively.
        // Every query starts with a fresh binary tree, so this is expected to never happen,
        // unless someone writes 64 nested par_join calls or something equivalent.
        let new_end_bit = end_bit.checked_sub(width).expect(
            "TreeNodeIndex's free bits have been exhausted, make sure recursion is used carefully",
        );

        // Wrapping operations are used as an optimization, the edge cases are unreachable:
        // - `end_bit < 64`, since at least one bit is guaranteed to be set
        // - `new_end_bit == end_bit - width <= end_bit < 64`
        let path = self.0 & !(1u64.wrapping_shl(end_bit));
        let terminator = 1u64.wrapping_shl(new_end_bit);
        // For `width == 0` at the root the shift amount is 64, which `unbounded_shl` maps to 0.
        let branch_bits = i.unbounded_shl(new_end_bit.wrapping_add(1));
        TreeNodeIndex(path | terminator | branch_bits)
    }
}

/// Returns `ceil(log2(branch_num))`, the number of bits needed to tell `branch_num` branches
/// apart. `log2(0)` is considered undefined; the wrapping arithmetic yields 64 for it.
#[inline]
fn ceil_ilog2(branch_num: u64) -> u32 {
    // `ceil(log2(n))` equals the bit length of `n - 1`:
    // 64 leading zeros (for `n == 1`) give 0, otherwise the result is `floor(log2(n - 1)) + 1`.
    u64::BITS - branch_num.wrapping_sub(1).leading_zeros()
}

#[cfg(test)]
mod tests;
45 changes: 45 additions & 0 deletions compiler/rustc_data_structures/src/tree_node_index/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::tree_node_index::TreeNodeIndex;

// Smoke test: two nested `branch` calls must not panic for any branch count in `1..128`.
// The resulting indices themselves are not inspected.
#[test]
fn up_to_16() {
    let root = TreeNodeIndex::root();
    for n in 1..128 {
        (0..n).for_each(|i| {
            root.branch(i, n).branch(n - i - 1, n);
        });
    }
}

// Checks `ceil_ilog2` against hand-computed `(input, expected)` pairs.
#[test]
fn ceil_log2() {
    let expectations: [(u64, u32); 9] =
        [(1, 0), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 3), (8, 3), (u64::MAX, 64)];
    expectations.into_iter().for_each(|(x, y)| {
        let r = super::ceil_ilog2(x);
        assert!(r == y, "ceil_ilog2({x}) == {r} != {y}");
    });
}

// Follows a single path step by step and checks the raw encoding after every step.
#[test]
fn some_cases() {
    // `(i, n, expected raw value after branching)`, applied in this order.
    let steps: [(u64, u64, u64); 6] = [
        (0xDEAD, 0xFADE, 0xDEAD8000_00000000),
        (0xBEEF, 0xCCCC, 0xDEADBEEF_80000000),
        (1, 2, 0xDEADBEEF_C0000000),
        (0, 2, 0xDEADBEEF_A0000000),
        (3, 4, 0xDEADBEEF_B8000000),
        (0xAAAAAA, 0xBBBBBB, 0xDEADBEEF_BAAAAAA8),
    ];

    let mut tni = TreeNodeIndex::root();
    for (i, n, expected) in steps {
        tni = tni.branch(i, n);
        assert_eq!(tni.0, expected);
    }
}

// Boundary behaviour of `branch` when applied directly to the root.
#[test]
fn edge_cases() {
    // Declared as a `const` so that `root` is also exercised in a const context.
    const START: TreeNodeIndex = TreeNodeIndex::root();

    // A node with a single branch consumes no bits: the index stays the root.
    let only_child = START.branch(0, 1);
    assert_eq!(only_child, TreeNodeIndex::root());

    // 63 bits are available below the root; using all of them for the highest branch
    // number sets every bit.
    let last_branch = START.branch(u64::MAX >> 1, 1 << 63);
    assert_eq!(last_branch.0, u64::MAX);

    // Same width with branch 0: only the terminator bit remains, at the lowest position.
    let first_branch = START.branch(0, 1 << 63);
    assert_eq!(first_branch.0, 1);
}
2 changes: 1 addition & 1 deletion compiler/rustc_incremental/src/persist/save.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::fs;
use std::sync::Arc;

use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::sync::par_join;
use rustc_middle::dep_graph::{
DepGraph, SerializedDepGraph, WorkProduct, WorkProductId, WorkProductMap,
};
use rustc_middle::sync::par_join;
use rustc_middle::ty::TyCtxt;
use rustc_serialize::Encodable as RustcEncodable;
use rustc_serialize::opaque::{FileEncodeResult, FileEncoder};
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_interface/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ pub fn try_print_query_stack(
if let Some(icx) = icx {
ty::print::with_no_queries!(print_query_stack(
icx.tcx,
icx.query,
icx.query.map(|i| i.id),
dcx,
limit_frames,
file,
Expand Down
Loading