From bf0b336e892b76c0fcd73ab64426014d8c1d859f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Mar 2026 17:49:55 +0530 Subject: [PATCH 1/2] feat: use state priorities in bareiss pivot selection --- src/singularity_removal.jl | 115 +++++++++++++++++++++++++++++++------ 1 file changed, 99 insertions(+), 16 deletions(-) diff --git a/src/singularity_removal.jl b/src/singularity_removal.jl index ca2814d..dc67b32 100644 --- a/src/singularity_removal.jl +++ b/src/singularity_removal.jl @@ -55,7 +55,7 @@ the `constraint`. @inline function find_first_linear_variable(M::SparseMatrixCLIL, range, mask, - constraint) + constraint, ::Nothing = nothing) eadj = M.row_cols @inbounds for i in range vertices = eadj[i] @@ -70,10 +70,33 @@ the `constraint`. return nothing end +@inline function find_first_linear_variable( + M::SparseMatrixCLIL, + range, + mask, + constraint, var_priorities::AbstractVector{Int} + ) + eadj = M.row_cols + @inbounds for i in range + vertices = eadj[i] + constraint(length(vertices)) || continue + candidate_v = 0 + candidate_val = 0 + for (j, v) in enumerate(vertices) + mask === nothing || mask[v] || continue + iszero(candidate_v) || var_priorities[v] < var_priorities[candidate_v] || continue + candidate_v = v + candidate_val = M.row_vals[i][j] + end + iszero(candidate_v) || return CartesianIndex(i, candidate_v), candidate_val + end + return nothing +end + @inline function find_first_linear_variable(M::AbstractMatrix, range, mask, - constraint) + constraint, ::Nothing = nothing) @inbounds for i in range row = @view M[i, :] if constraint(count(!iszero, row)) @@ -87,12 +110,36 @@ end return nothing end -function find_masked_pivot(variables, M, k) - r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(1)) +@inline function find_first_linear_variable( + M::AbstractMatrix, + range, + mask, + constraint, var_priorities::AbstractVector{Int} + ) + @inbounds for i in range + row = @view M[i, :] + constraint(count(!iszero, row)) || continue + candidate_v = 0 + candidate_val = 0 + for (v, val) in enumerate(row) + mask === nothing || mask[v] || continue + if iszero(candidate_v) || var_priorities[v] < var_priorities[candidate_v] + candidate_v = v + candidate_val = val + end + end + iszero(candidate_v) && return nothing + return CartesianIndex(i, candidate_v), candidate_val + end + return nothing +end + +function find_masked_pivot(variables, M, k, var_priorities) + r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(1), var_priorities) r !== nothing && return r - r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(2)) + r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(2), var_priorities) r !== nothing && return r - r = find_first_linear_variable(M, k:size(M, 1), variables, _ -> true) + r = find_first_linear_variable(M, k:size(M, 1), variables, _ -> true, var_priorities) return r end @@ -207,14 +254,15 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti} end end solvable_variables = findall(is_linear_variables) + var_priorities = has_state_priorities(structure) ? get_state_priorities(structure) : nothing local bar try - bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff) + bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities) catch e e isa OverflowError || rethrow(e) mm = convert(SparseMatrixCLIL{BigInt, Ti}, mm_orig) - bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff) + bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities) end # This phrasing infers the return type as `Union{Tuple{...}}` instead of @@ -243,6 +291,18 @@ end (s::SyncedSwapRows{Nothing})(M, i::Int, j::Int) = Base.swaprows!(M, i, j) (s::SyncedSwapRows)(M, i::Int, j::Int) = (Base.swaprows!(s.Mold, i, j); Base.swaprows!(M, i, j)) +""" + $TYPEDEF + +Lazy `&&` of two boolean masks. Only implements whatever is required for `find_masked_pivot`. +""" +struct LazyMaskAnd{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}} + mask1::V1 + mask2::V2 +end + +Base.getindex(lma::LazyMaskAnd, i::Integer) = lma.mask1[i] && lma.mask2[i] + """ $(TYPEDEF) @@ -253,12 +313,21 @@ Mutable state threaded through the Bareiss factorization callbacks. - `pivots`: accumulates the column index of every pivot chosen during elimination. - `is_linear_variables`/`is_highest_diff`: masks used for the tiered pivot search. """ -mutable struct BareissContext{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}} +mutable struct BareissContext{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}, P <: Union{Nothing, AbstractVector{Int}}} rank1::Union{Nothing, Int} rank2::Union{Nothing, Int} pivots::Vector{Int} is_linear_variables::V1 is_highest_diff::V2 + valid_pivot_mask::BitVector + var_priorities::P +end + +function BareissContext(is_linear_variables, is_highest_diff, var_priorities = nothing) + return BareissContext( + nothing, nothing, Int[], is_linear_variables, is_highest_diff, + trues(length(is_linear_variables)), var_priorities + ) end """ @@ -273,7 +342,8 @@ The column index of every selected pivot is appended to `ctx.pivots`. """ function (ctx::BareissContext)(M, k::Int) if ctx.rank1 === nothing - r = find_masked_pivot(ctx.is_linear_variables, M, k) + mask = LazyMaskAnd(ctx.is_linear_variables, ctx.valid_pivot_mask) + r = find_masked_pivot(ctx.is_linear_variables, M, k, ctx.var_priorities) if r !== nothing push!(ctx.pivots, r[1][2]) return r @@ -281,7 +351,8 @@ function (ctx::BareissContext)(M, k::Int) ctx.rank1 = k - 1 end if ctx.rank2 === nothing - r = find_masked_pivot(ctx.is_highest_diff, M, k) + mask = LazyMaskAnd(ctx.is_highest_diff, ctx.valid_pivot_mask) + r = find_masked_pivot(ctx.is_highest_diff, M, k, ctx.var_priorities) if r !== nothing push!(ctx.pivots, r[1][2]) return r @@ -291,16 +362,28 @@ function (ctx::BareissContext)(M, k::Int) # TODO: It would be better to sort the variables by # derivative order here to enable more elimination # opportunities. - r = find_masked_pivot(nothing, M, k) + r = find_masked_pivot(nothing, M, k, ctx.var_priorities) r !== nothing && push!(ctx.pivots, r[1][2]) return r end +struct BareissContextUpdate{C <: BareissContext, F} + context::C + inner_update::F +end + +function (bcu::BareissContextUpdate)(zero!, M, k, swapto, pivot, last_pivot; kw...) + ctx = bcu.context + col = swapto[2] + ctx.valid_pivot_mask[col] = false + return bcu.inner_update(zero!, M, k, swapto, pivot, last_pivot; kw...) +end + function do_bareiss!(M, Mold, is_linear_variables::AbstractVector{Bool}, - is_highest_diff::AbstractVector{Bool}) - ctx = BareissContext(nothing, nothing, Int[], is_linear_variables, is_highest_diff) - bareiss_ops = (noop_colswap, SyncedSwapRows(Mold), - bareiss_update_virtual_colswap_mtk!, bareiss_zero!) + is_highest_diff::AbstractVector{Bool}, var_priorities = nothing) + ctx = BareissContext(is_linear_variables, is_highest_diff, var_priorities) + update! = BareissContextUpdate(ctx, bareiss_update_virtual_colswap_mtk!) + bareiss_ops = (noop_colswap, SyncedSwapRows(Mold), update!, bareiss_zero!) rank3, = bareiss!(M, bareiss_ops; find_pivot = ctx) rank2 = something(ctx.rank2, rank3) rank1 = something(ctx.rank1, rank2) From 3d363c03372ab09b8c7208a0247c7024f4e105e2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Mar 2026 17:45:45 +0530 Subject: [PATCH 2/2] feat: allow returning additional pivot information from alias elimination --- src/singularity_removal.jl | 59 ++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/src/singularity_removal.jl b/src/singularity_removal.jl index dc67b32..1e558df 100644 --- a/src/singularity_removal.jl +++ b/src/singularity_removal.jl @@ -15,15 +15,21 @@ end level === nothing ? v : (v => level) end -function structural_singularity_removal!(state::TransformationState; - variable_underconstrained! = force_var_to_zero!, kwargs...) +function structural_singularity_removal!( + state::TransformationState, ::Val{ReturnPivots} = Val{false}(); + variable_underconstrained! = force_var_to_zero!, kwargs... + ) where {ReturnPivots} mm = linear_subsys_adjmat!(state; kwargs...) if size(mm, 1) == 0 - return mm # No linear subsystems + if ReturnPivots + return mm, PivotInfo(0, 0, Int[]) + else + return mm # No linear subsystems + end end (; graph, var_to_diff, solvable_graph) = state.structure - mm = structural_singularity_removal!(state, mm; variable_underconstrained!) + mm, pivotinfo = structural_singularity_removal!(state, mm, Val{true}(); variable_underconstrained!) s = state.structure for (ei, e) in enumerate(mm.nzrows) set_neighbors!(s.graph, e, mm.row_cols[ei]) @@ -34,7 +40,11 @@ function structural_singularity_removal!(state::TransformationState; end end - return mm + if ReturnPivots + return mm, pivotinfo + else + return mm + end end # For debug purposes @@ -404,8 +414,37 @@ function force_var_to_zero!(structure::SystemStructure, ils::SparseMatrixCLIL, v return ils end -function structural_singularity_removal!(state::TransformationState, ils::SparseMatrixCLIL; - variable_underconstrained! = force_var_to_zero!) +""" + $TYPEDSIGNATURES + +Information about the pivots chosen by Bareiss during `structural_singularity_removal!`. +This can be returned from `structural_singularity_removal!` by passing `Val(true)` as the last +positional argument. + +$TYPEDFIELDS +""" +struct PivotInfo + """ + The length of the prefix of `pivots` that is variables which _only_ occur in linear + equations of the sort considered by this pass. These variables must be solved for + using the integer coefficient equations considered by this pass. + """ + n_linear_vars::Int + """ + Number of elements in `pivots` after `n_linear_vars` corresponding to highest order + derivative variables. + """ + n_highest_diff_vars::Int + """ + The list of pivots chosen by the Bareiss algorithm. + """ + pivots::Vector{Int} +end + +function structural_singularity_removal!( + state::TransformationState, ils::SparseMatrixCLIL, ::Val{ReturnPivots} = Val{false}(); + variable_underconstrained! = force_var_to_zero! + ) where {ReturnPivots} (; structure) = state (; graph, solvable_graph, var_to_diff, eq_to_diff) = state.structure # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear @@ -420,5 +459,9 @@ function structural_singularity_removal!(state::TransformationState, ils::Sparse ils = variable_underconstrained!(structure, ils, v) end - return ils + if ReturnPivots + return ils, PivotInfo(rank1, rank2, pivots) + else + return ils + end end