Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
name = "ArrayDiff"
uuid = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"]
version = "0.1.0"
authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@gmail.com>"]

[deps]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Calculus = "0.5.2"
DataStructures = "0.18, 0.19"
ForwardDiff = "1"
MathOptInterface = "1.40"
NaNMath = "1"
SparseArrays = "1.10"
SpecialFunctions = "2.6.1"
julia = "1.10"
3 changes: 2 additions & 1 deletion src/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ import NaNMath:
include("Coloring/Coloring.jl")
include("graph_tools.jl")
include("sizes.jl")
include("univariate_expressions.jl")
include("operators.jl")
include("types.jl")
include("utils.jl")

include("reverse_mode.jl")
include("forward_over_reverse.jl")
include("mathoptinterface_api.jl")
include("operators.jl")
include("model.jl")
include("parse.jl")
include("evaluator.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/mathoptinterface_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

_no_hessian(op::MOI.Nonlinear._UnivariateOperator) = op.f′′ === nothing
_no_hessian(op::_UnivariateOperator) = op.f′′ === nothing
_no_hessian(op::MOI.Nonlinear._MultivariateOperator) = op.∇²f === nothing

function MOI.features_available(d::NLPEvaluator)
Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function register_operator(
elseif haskey(registry.multivariate_operator_to_id, op)
error("Operator $op is already registered.")
end
operator = Nonlinear._UnivariateOperator(op, f...)
operator = _UnivariateOperator(op, f...)
push!(registry.univariate_operators, op)
push!(registry.registered_univariate_operators, operator)
registry.univariate_operator_to_id[op] =
Expand Down
200 changes: 198 additions & 2 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,158 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
:row,
]

function _validate_register_assumptions(
f::Function,
name::Symbol,
nb_args::Integer,
)
# Assumption 1: check that `f` can be called with `Float64` arguments.
arg = nb_args == 1 ? 0.0 : zeros(nb_args)
if hasmethod(f, Tuple{typeof(arg)})
y = f(arg)
else
error(
"Unable to register the function :$name.\n\n" *
"The function must be able to be called with $nb_args Float64 " *
"arguments, but no method was found for this.",
)
end
if !(y isa Real)
error(
"Expected return type of `Float64` from the user-defined " *
"function :$(name), but got `$(typeof(y))`.",
)
end
# Assumption 2: check that `f` can be differentiated using `ForwardDiff`.
try
if nb_args == 1
ForwardDiff.derivative(f, 0.0)
else
ForwardDiff.gradient(x -> f(x...), zeros(nb_args))
end
catch err
if err isa MethodError
error(
"Unable to register the function :$name.\n\n" *
_FORWARD_DIFF_METHOD_ERROR_HELPER,
)
end
# We hit some other error, perhaps we called a function like log(-1).
# Ignore for now, and hope that a useful error is shown to the user
# during the solve.
end
return
end

function _checked_derivative(f::F, op::Symbol) where {F}
return function (x)
try
return ForwardDiff.derivative(f, x)
catch err
_intercept_ForwardDiff_MethodError(err, op)
end
end
end

"""
check_return_type(::Type{T}, ret::S) where {T,S}

Overload this method for new types `S` to throw an informative error if a
user-defined function returns the type `S` instead of `T`.
"""
check_return_type(::Type{T}, ret::T) where {T} = nothing

function check_return_type(::Type{T}, ret) where {T}
return error(
"Expected return type of $T from a user-defined function, but got " *
"$(typeof(ret)).",
)
end

struct _UnivariateOperator{F,F′,F′′}
f::F
f′::F′
f′′::F′′
function _UnivariateOperator(
f::Function,
f′::Function,
f′′::Union{Nothing,Function} = nothing,
)
return new{typeof(f),typeof(f′),typeof(f′′)}(f, f′, f′′)
end
end

function _UnivariateOperator(op::Symbol, f::Function)
_validate_register_assumptions(f, op, 1)
f′ = _checked_derivative(f, op)
return _UnivariateOperator(op, f, f′)
end

function _UnivariateOperator(op::Symbol, f::Function, f′::Function)
try
_validate_register_assumptions(f′, op, 1)
f′′ = _checked_derivative(f′, op)
return _UnivariateOperator(f, f′, f′′)
catch
return _UnivariateOperator(f, f′, nothing)
end
end

function _UnivariateOperator(::Symbol, f::Function, f′::Function, f′′::Function)
return _UnivariateOperator(f, f′, f′′)
end

struct OperatorRegistry
# NODE_CALL_UNIVARIATE
univariate_operators::Vector{Symbol}
univariate_operator_to_id::Dict{Symbol,Int}
univariate_user_operator_start::Int
registered_univariate_operators::Vector{_UnivariateOperator}
# NODE_CALL_MULTIVARIATE
multivariate_operators::Vector{Symbol}
multivariate_operator_to_id::Dict{Symbol,Int}
multivariate_user_operator_start::Int
registered_multivariate_operators::Vector{
MOI.Nonlinear._MultivariateOperator,
}
# NODE_LOGIC
logic_operators::Vector{Symbol}
logic_operator_to_id::Dict{Symbol,Int}
# NODE_COMPARISON
comparison_operators::Vector{Symbol}
comparison_operator_to_id::Dict{Symbol,Int}
function OperatorRegistry()
univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
logic_operators = [:&&, :||]
comparison_operators = [:<=, :(==), :>=, :<, :>]
return new(
# NODE_CALL_UNIVARIATE
univariate_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(univariate_operators)
),
length(univariate_operators),
_UnivariateOperator[],
# NODE_CALL
multivariate_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(multivariate_operators)
),
length(multivariate_operators),
MOI.Nonlinear._MultivariateOperator[],
# NODE_LOGIC
logic_operators,
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
# NODE_COMPARISON
comparison_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(comparison_operators)
),
)
end
end

function eval_logic_function(
::OperatorRegistry,
op::Symbol,
Expand All @@ -34,6 +186,23 @@ function eval_logic_function(
end
end

function _generate_eval_univariate()
exprs = map(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS) do op
return :(
return (
value_deriv_and_second($op, x)[1],
value_deriv_and_second($op, x)[2],
)
)
end
return Nonlinear._create_binary_switch(1:length(exprs), exprs)
end

@eval @inline function _eval_univariate(id, x::T) where {T}
$(_generate_eval_univariate())
return error("Invalid id for univariate operator: $id")
end

function eval_multivariate_function(
registry::OperatorRegistry,
op::Symbol,
Expand Down Expand Up @@ -165,17 +334,44 @@ function eval_multivariate_hessian(
return true
end

function eval_univariate_function(operator::_UnivariateOperator, x::T) where {T}
ret = operator.f(x)
check_return_type(T, ret)
return ret::T
end

function eval_univariate_gradient(operator::_UnivariateOperator, x::T) where {T}
ret = operator.f′(x)
check_return_type(T, ret)
return ret::T
end

function eval_univariate_hessian(operator::_UnivariateOperator, x::T) where {T}
ret = operator.f′′(x)
check_return_type(T, ret)
return ret::T
end

function eval_univariate_function_and_gradient(
operator::_UnivariateOperator,
x::T,
) where {T}
ret_f = eval_univariate_function(operator, x)
ret_f′ = eval_univariate_gradient(operator, x)
return ret_f, ret_f′
end

function eval_univariate_function_and_gradient(
registry::OperatorRegistry,
id::Integer,
x::T,
) where {T}
if id <= registry.univariate_user_operator_start
return Nonlinear._eval_univariate(id, x)::Tuple{T,T}
return _eval_univariate(id, x)::Tuple{T,T}
end
offset = id - registry.univariate_user_operator_start
operator = registry.registered_univariate_operators[offset]
return Nonlinear.eval_univariate_function_and_gradient(operator, x)
return eval_univariate_function_and_gradient(operator, x)
end

function eval_multivariate_gradient(
Expand Down
51 changes: 0 additions & 51 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,57 +133,6 @@ struct _FunctionStorage
end
end

struct OperatorRegistry
# NODE_CALL_UNIVARIATE
univariate_operators::Vector{Symbol}
univariate_operator_to_id::Dict{Symbol,Int}
univariate_user_operator_start::Int
registered_univariate_operators::Vector{MOI.Nonlinear._UnivariateOperator}
# NODE_CALL_MULTIVARIATE
multivariate_operators::Vector{Symbol}
multivariate_operator_to_id::Dict{Symbol,Int}
multivariate_user_operator_start::Int
registered_multivariate_operators::Vector{
MOI.Nonlinear._MultivariateOperator,
}
# NODE_LOGIC
logic_operators::Vector{Symbol}
logic_operator_to_id::Dict{Symbol,Int}
# NODE_COMPARISON
comparison_operators::Vector{Symbol}
comparison_operator_to_id::Dict{Symbol,Int}
function OperatorRegistry()
univariate_operators = copy(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
logic_operators = [:&&, :||]
comparison_operators = [:<=, :(==), :>=, :<, :>]
return new(
# NODE_CALL_UNIVARIATE
univariate_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(univariate_operators)
),
length(univariate_operators),
MOI.Nonlinear._UnivariateOperator[],
# NODE_CALL
multivariate_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(multivariate_operators)
),
length(multivariate_operators),
MOI.Nonlinear._MultivariateOperator[],
# NODE_LOGIC
logic_operators,
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
# NODE_COMPARISON
comparison_operators,
Dict{Symbol,Int}(
op => i for (i, op) in enumerate(comparison_operators)
),
)
end
end

"""
Model()

Expand Down
Loading
Loading