Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@g
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand All @@ -17,6 +18,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Calculus = "0.5.2"
DataStructures = "0.18, 0.19"
ForwardDiff = "1"
JuMP = "1.29.4"
MathOptInterface = "1.40"
NaNMath = "1"
SparseArrays = "1.10"
Expand Down
9 changes: 9 additions & 0 deletions perf/neural.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Needs https://github.com/jump-dev/JuMP.jl/pull/3451
# Smoke test / micro-benchmark: create a block of JuMP variables stored as an
# `ArrayDiff.ArrayOfVariables` and build the lazy array expression `W * X`
# (no scalarization happens; the product stays vectorized).
using JuMP
using ArrayDiff

n = 2
X = rand(n, n)
model = Model()
# `container = ArrayDiff.ArrayOfVariables` routes container construction
# through the `JuMP.Containers.container` hook defined in src/JuMP/variables.jl.
@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
W * X
2 changes: 2 additions & 0 deletions src/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ function Evaluator(
return Evaluator(model, NLPEvaluator(model, ordered_variables))
end

include("JuMP/JuMP.jl")

end # module
11 changes: 11 additions & 0 deletions src/JuMP/JuMP.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# JuMP extension

import JuMP

# Equivalent of `AbstractJuMPScalar` but for arrays: the common supertype of
# all array-valued JuMP objects in this extension (compact variable blocks,
# lazy array expressions). `T` is the element type, `N` the array rank.
abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end

include("variables.jl")
include("nlp_expr.jl")
include("operators.jl")
include("print.jl")
16 changes: 16 additions & 0 deletions src/JuMP/nlp_expr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Lazy, unevaluated array-valued nonlinear expression. Mirrors the design of
# `JuMP.GenericNonlinearExpr` (a `head` symbol naming the operation plus a
# vector of operands) but with array semantics: the expression is kept
# vectorized rather than expanded into scalar expressions entry by entry.
# `V` is the variable-reference type, `N` the rank of the resulting array.
struct GenericArrayExpr{V<:JuMP.AbstractVariableRef,N} <:
       AbstractJuMPArray{JuMP.GenericNonlinearExpr{V},N}
    head::Symbol          # operation symbol, e.g. `:*` for a matrix product
    args::Vector{Any}     # operands: arrays, variable containers, constants
    size::NTuple{N,Int}   # dimensions of the resulting array expression
end

const ArrayExpr{N} = GenericArrayExpr{JuMP.VariableRef,N}

# Scalar indexing is deliberately unsupported: array expressions must stay
# vectorized, so `expr[i, j]` is an error rather than a scalarized expression.
function Base.getindex(::GenericArrayExpr, args...)
    msg = "`getindex` not implemented, build vectorized expression instead"
    return error(msg)
end

Base.size(expr::GenericArrayExpr) = expr.size
7 changes: 7 additions & 0 deletions src/JuMP/operators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
    Base.:*(A::MatrixOfVariables, B::Matrix)

Build a lazy [`GenericArrayExpr`](@ref) representing the matrix product
`A * B`, instead of materializing scalar nonlinear expressions entry by
entry. The result has size `(size(A, 1), size(B, 2))`.

Throws `DimensionMismatch` if the inner dimensions of `A` and `B` differ.
"""
function Base.:(*)(A::MatrixOfVariables, B::Matrix)
    # Validate the inner dimensions up front; otherwise an invalid product
    # would be silently recorded and only fail (or mislead) much later.
    if size(A, 2) != size(B, 1)
        throw(
            DimensionMismatch(
                "cannot multiply a matrix of size $(size(A)) with a matrix of size $(size(B))",
            ),
        )
    end
    return GenericArrayExpr{JuMP.variable_ref_type(A.model),2}(
        :*,
        Any[A, B],
        (size(A, 1), size(B, 2)),
    )
end
11 changes: 11 additions & 0 deletions src/JuMP/print.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Multi-line REPL display for a compact variable block: the standard array
# summary followed by the variable-index offset.
function Base.show(io::IO, ::MIME"text/plain", v::ArrayOfVariables)
    print(io, Base.summary(v), " with offset ", v.offset)
    return
end

# Multi-line REPL display for a lazy array expression: just the summary
# (type and dimensions) — the expression tree itself is not printed.
function Base.show(io::IO, ::MIME"text/plain", expr::GenericArrayExpr)
    print(io, Base.summary(expr))
    return
end

# Plain `show` falls through to the `text/plain` display so that
# `sprint(show, x)` and the REPL render these arrays identically.
Base.show(io::IO, v::AbstractJuMPArray) = show(io, MIME"text/plain"(), v)
45 changes: 45 additions & 0 deletions src/JuMP/variables.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Originally extracted from GenOpt; once ArrayDiff is a dependency of GenOpt,
# the duplicate copy in GenOpt can be removed.

# Compact representation of an `N`-dimensional block of consecutively-created
# JuMP variables. Only the model, the index offset of the first variable and
# the dimensions are stored; individual `GenericVariableRef`s are rebuilt on
# demand in `getindex`, so the container is O(1) in memory.
struct ArrayOfVariables{T,N} <: AbstractJuMPArray{JuMP.GenericVariableRef{T},N}
    model::JuMP.GenericModel{T}
    offset::Int64          # MOI index of the first element is `offset + 1`
    size::NTuple{N,Int64}  # array dimensions of the variable block
end

const MatrixOfVariables{T} = ArrayOfVariables{T,2}

Base.size(array::ArrayOfVariables) = array.size
"""
    Base.getindex(A::ArrayOfVariables{T}, I...) where {T}

Rebuild the `JuMP.GenericVariableRef` at position `I` (linear or cartesian
indexing) from the stored offset, without ever materializing the full array
of references.
"""
function Base.getindex(A::ArrayOfVariables{T}, I...) where {T}
    # `LinearIndices` is the public API for converting cartesian indices to
    # a linear index (and passes a single linear index straight through);
    # the original used the internal `Base._to_linear_index`, which is not
    # a stable API and performed no bounds checking.
    index = A.offset + LinearIndices(A.size)[I...]
    return JuMP.GenericVariableRef{T}(A.model, MOI.VariableIndex(index))
end

# Hook into `@variable(model, ..., container = ArrayOfVariables)`: delegate
# to the dense `Array` container first, then compress the result into the
# O(1)-memory `ArrayOfVariables` representation.
function JuMP.Containers.container(
    f::Function,
    indices::JuMP.Containers.VectorizedProductIterator{
        NTuple{N,Base.OneTo{Int}},
    },
    ::Type{ArrayOfVariables},
) where {N}
    dense = JuMP.Containers.container(f, indices, Array)
    return to_generator(dense)
end

JuMP._is_real(::ArrayOfVariables) = true

"""
    Base.convert(::Type{ArrayOfVariables{T,N}}, array)

Compress a dense array of variable references into an [`ArrayOfVariables`](@ref),
which stores only the model, the index offset of the first variable and the
dimensions.

Throws `ArgumentError` if `array` is empty, if its variables belong to more
than one model, or if their indices are not consecutive in array order.
"""
function Base.convert(
    ::Type{ArrayOfVariables{T,N}},
    array::Array{JuMP.GenericVariableRef{T},N},
) where {T,N}
    # Explicit validation instead of `@assert`: assertions may be compiled
    # out, and `array[1]` on an empty array would throw a confusing
    # `BoundsError` rather than a descriptive error.
    isempty(array) &&
        throw(ArgumentError("cannot convert an empty array of variables"))
    model = JuMP.owner_model(first(array))
    offset = JuMP.index(first(array)).value - 1
    for i in eachindex(array)
        if JuMP.owner_model(array[i]) !== model
            throw(
                ArgumentError(
                    "variables belong to different models; cannot compress",
                ),
            )
        end
        if JuMP.index(array[i]).value != offset + i
            throw(
                ArgumentError(
                    "variable indices are not consecutive; cannot compress",
                ),
            )
        end
    end
    return ArrayOfVariables{T,N}(model, offset, size(array))
end

# Thin wrapper used by the `JuMP.Containers.container` hook: compress a dense
# array of consecutively-created variables into an `ArrayOfVariables`.
function to_generator(array::Array{JuMP.GenericVariableRef{T},N}) where {T,N}
    return convert(ArrayOfVariables{T,N}, array)
end
43 changes: 43 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
module TestJuMP

using Test

using JuMP
using ArrayDiff

# Run every `test_`-prefixed function defined in this module, each inside
# its own `@testset` named after the function.
function runtests()
    all_names = names(@__MODULE__; all = true)
    for name in filter(s -> startswith("$(s)", "test_"), all_names)
        @testset "$(name)" begin
            getfield(@__MODULE__, name)()
        end
    end
    return
end

# Exercise the compact variable container and the lazy matrix product:
# indexing, printing, and the vectorized-only indexing contract.
function test_array_product()
    n = 2
    X = rand(n, n)
    model = Model()
    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
    @test W isa ArrayDiff.MatrixOfVariables{Float64}
    # Variables are stored column-major, rebuilt from the offset on demand.
    @test JuMP.index(W[1, 1]) == MOI.VariableIndex(1)
    @test JuMP.index(W[2, 1]) == MOI.VariableIndex(2)
    @test JuMP.index(W[2]) == MOI.VariableIndex(2)
    @test sprint(show, W) ==
          "2×2 ArrayDiff.ArrayOfVariables{Float64, 2} with offset 0"
    product = W * X
    @test product isa ArrayDiff.ArrayExpr{2}
    @test sprint(show, product) ==
          "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}"
    err = ErrorException(
        "`getindex` not implemented, build vectorized expression instead",
    )
    @test_throws err product[1, 1]
    return
end

end # module

TestJuMP.runtests()
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
[deps]
ArrayDiff = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
GenOpt = "f2c049d8-7489-4223-990c-4f1c121a4cde"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
ArrayDiff = {path = ".."}
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include("ReverseAD.jl")
include("ArrayDiff.jl")
include("JuMP.jl")
Loading