Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
use crate::aggregates::group_values::GroupValues;
use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
use arrow::array::{
ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, PrimitiveArray,
cast::AsArray,
Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder,
PrimitiveArray, cast::AsArray,
};
use arrow::buffer::NullBuffer;
use arrow::datatypes::{DataType, i256};
use datafusion_common::Result;
use datafusion_common::hash_utils::RandomState;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
use half::f16;
use hashbrown::hash_table::HashTable;
#[cfg(not(feature = "force_hash_collisions"))]
use std::hash::BuildHasher;
use std::mem::size_of;
use std::sync::Arc;
Expand Down Expand Up @@ -109,46 +111,224 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
}
}

impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T>
where
T::Native: HashValue,
{
/// Find the group index for `key` if it already exists in the map.
#[inline(always)]
fn find_group(
map: &HashTable<(usize, u64)>,
values: &[T::Native],
key: T::Native,
hash: u64,
) -> Option<usize> {
// SAFETY: `g` is always a valid index into `values` because it was set
// to `values.len()` at insertion time and values are never removed
// (only via emit, which also clears or adjusts the map).
map.find(hash, |&(g, h)| unsafe {
hash == h && values.get_unchecked(g).is_eq(key)
})
.map(|&(g, _)| g)
}

/// Insert a new group for `key` that is known not to exist yet.
#[inline(always)]
fn insert_new_group(
map: &mut HashTable<(usize, u64)>,
values: &mut Vec<T::Native>,
key: T::Native,
hash: u64,
) -> usize {
let g = values.len();
values.push(key);
map.insert_unique(hash, (g, hash), |&(_, h)| h);
g
}

/// Find an existing group or insert a new one.
#[inline(always)]
fn lookup_or_insert(
map: &mut HashTable<(usize, u64)>,
values: &mut Vec<T::Native>,
key: T::Native,
hash: u64,
) -> usize {
if let Some(g) = Self::find_group(map, values, key, hash) {
g
} else {
Self::insert_new_group(map, values, key, hash)
}
}

/// Get or create the null group index.
#[inline(always)]
fn get_or_create_null_group(
null_group: &mut Option<usize>,
values: &mut Vec<T::Native>,
) -> usize {
*null_group.get_or_insert_with(|| {
let g = values.len();
values.push(Default::default());
g
})
}
}

impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T>
where
T::Native: HashValue,
{
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
assert_eq!(cols.len(), 1);
let array = cols[0].as_primitive::<T>();
let len = array.len();
groups.clear();
groups.reserve(len);

if len == 0 {
return Ok(());
}

let values = array.values().as_ref();
let nulls: Option<&NullBuffer> = array.nulls();

// Step 1: Batch compute all hashes using thread-local buffer
// (avoids per-call allocation and separates hashing from hash table ops
// for better vectorization / instruction pipelining)
datafusion_common::hash_utils::with_hashes(cols, &self.random_state, |hashes| {
// Step 2: Process in chunks of 4 for better ILP and local dedup
let value_chunks = values.chunks_exact(4);
let hash_chunks = hashes.chunks_exact(4);
let rem_len = value_chunks.remainder().len();

for (chunk_idx, (vs, hs)) in value_chunks.zip(hash_chunks).enumerate() {
let base = chunk_idx * 4;
let mut gids = [0usize; 4];

// Local dedup within the chunk: dedup[i] = index of first
// equivalent entry in this chunk (i itself if unique).
// Compare only values (not hashes) since for 4 elements
// the value comparison is cheap enough.
let mut dedup = [0u8, 1, 2, 3];

if let Some(nulls) = nulls {
let valid = [
nulls.is_valid(base),
nulls.is_valid(base + 1),
nulls.is_valid(base + 2),
nulls.is_valid(base + 3),
];

for i in 1..4 {
for j in 0..i {
if (!valid[i] && !valid[j])
|| (valid[i] && valid[j] && vs[i].is_eq(vs[j]))
{
dedup[i] = dedup[j];
break;
}
}
}

// Phase 1: Batch find - lookup all unique entries
let mut found = [None; 4];
for i in 0..4 {
if dedup[i] as usize != i {
continue;
}
if !valid[i] {
found[i] = self.null_group;
} else {
found[i] =
Self::find_group(&self.map, &self.values, vs[i], hs[i]);
}
}

// Phase 2: Insert entries not found
for i in 0..4 {
if dedup[i] as usize != i {
gids[i] = gids[dedup[i] as usize];
continue;
}
if let Some(g) = found[i] {
gids[i] = g;
} else if !valid[i] {
gids[i] = Self::get_or_create_null_group(
&mut self.null_group,
&mut self.values,
);
} else {
gids[i] = Self::insert_new_group(
&mut self.map,
&mut self.values,
vs[i],
hs[i],
);
}
}
} else {
// Fast path: no nulls
for i in 1..4 {
for j in 0..i {
if vs[i].is_eq(vs[j]) {
dedup[i] = dedup[j];
break;
}
}
}

for v in cols[0].as_primitive::<T>() {
let group_id = match v {
None => *self.null_group.get_or_insert_with(|| {
let group_id = self.values.len();
self.values.push(Default::default());
group_id
}),
Some(key) => {
let state = &self.random_state;
let hash = key.hash(state);
let insert = self.map.entry(
hash,
|&(g, h)| unsafe {
hash == h && self.values.get_unchecked(g).is_eq(key)
},
|&(_, h)| h,
);

match insert {
hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
hashbrown::hash_table::Entry::Vacant(v) => {
let g = self.values.len();
v.insert((g, hash));
self.values.push(key);
g
// Phase 1: Batch find - lookup all unique entries
let mut found = [None; 4];
for i in 0..4 {
if dedup[i] as usize != i {
continue;
}
found[i] =
Self::find_group(&self.map, &self.values, vs[i], hs[i]);
}

// Phase 2: Insert entries not found
for i in 0..4 {
if dedup[i] as usize != i {
gids[i] = gids[dedup[i] as usize];
} else if let Some(g) = found[i] {
gids[i] = g;
} else {
gids[i] = Self::insert_new_group(
&mut self.map,
&mut self.values,
vs[i],
hs[i],
);
}
}
}
};
groups.push(group_id)
}
Ok(())

groups.extend_from_slice(&gids);
}

// Handle remainder (0-3 elements)
let rem_start = len - rem_len;
for i in 0..rem_len {
let idx = rem_start + i;
let is_valid = nulls.is_none_or(|n: &NullBuffer| n.is_valid(idx));

let group_id = if !is_valid {
Self::get_or_create_null_group(&mut self.null_group, &mut self.values)
} else {
Self::lookup_or_insert(
&mut self.map,
&mut self.values,
values[idx],
hashes[idx],
)
};
groups.push(group_id);
}

Ok(())
})
}

fn size(&self) -> usize {
Expand Down
Loading