2 changes: 1 addition & 1 deletion .gitignore
@@ -2,4 +2,4 @@
*.jl.mem
docs/build
docs/site
Manifest.toml
Manifest*.toml
4 changes: 4 additions & 0 deletions src/NLPModels.jl
@@ -42,4 +42,8 @@ for f in ["utils", "api", "counters", "meta", "show", "tools"]
include("nls/$f.jl")
end

include("nlp/batch/api.jl")
include("nlp/batch/foreach.jl")
include("nlp/batch/inplace.jl")

end # module
82 changes: 82 additions & 0 deletions src/nlp/batch/api.jl
@@ -0,0 +1,82 @@
export AbstractBatchNLPModel

"""
    AbstractBatchNLPModel

Abstract type for models that evaluate a batch of nonlinear optimization problems.
"""
abstract type AbstractBatchNLPModel end

# Redirect symbol-based counter updates to `Val` dispatch, so concrete batch models
# only need to implement `increment!(bnlp, ::Val{:counter_name})`.
function NLPModels.increment!(bnlp::AbstractBatchNLPModel, fun::Symbol)
  NLPModels.increment!(bnlp, Val(fun))
end

for fun in (
:batch_obj,
:batch_grad,
:batch_grad!,
:batch_objgrad,
:batch_objgrad!,
:batch_objcons,
:batch_objcons!,
:batch_cons,
:batch_cons!,
:batch_cons_lin,
:batch_cons_lin!,
:batch_cons_nln,
:batch_cons_nln!,
:batch_jth_con,
:batch_jth_congrad,
:batch_jth_congrad!,
:batch_jth_sparse_congrad,
:batch_jac_structure!,
:batch_jac_structure,
:batch_jac_coord!,
:batch_jac_coord,
:batch_jac,
:batch_jprod,
:batch_jprod!,
:batch_jtprod,
:batch_jtprod!,
:batch_jac_op,
:batch_jac_op!,
:batch_jac_lin_structure!,
:batch_jac_lin_structure,
:batch_jac_lin_coord!,
:batch_jac_lin_coord,
:batch_jac_lin,
:batch_jprod_lin,
:batch_jprod_lin!,
:batch_jtprod_lin,
:batch_jtprod_lin!,
:batch_jac_lin_op,
:batch_jac_lin_op!,
:batch_jac_nln_structure!,
:batch_jac_nln_structure,
:batch_jac_nln_coord!,
:batch_jac_nln_coord,
:batch_jac_nln,
:batch_jprod_nln,
:batch_jprod_nln!,
:batch_jtprod_nln,
:batch_jtprod_nln!,
:batch_jac_nln_op,
:batch_jac_nln_op!,
:batch_jth_hess_coord,
:batch_jth_hess_coord!,
:batch_jth_hess,
:batch_jth_hprod,
:batch_jth_hprod!,
:batch_ghjvprod,
:batch_ghjvprod!,
:batch_hess_structure!,
:batch_hess_structure,
:batch_hess_coord!,
:batch_hess_coord,
:batch_hess,
:batch_hprod,
:batch_hprod!,
:batch_hess_op,
:batch_hess_op!,
:batch_varscale,
:batch_lagscale,
:batch_conscale,
)
@eval export $fun
@eval function $fun end
end
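
The loop above only declares and exports empty generic functions; a concrete batch backend adds methods to them, as `ForEachBatchNLPModel` does in the next file. A minimal, hypothetical sketch of another backend (the `MyBatchModel` type and its fields are illustrative and not part of this pull request):

# Hypothetical sketch: one shared model evaluated at a batch of points.
struct MyBatchModel{M} <: AbstractBatchNLPModel
  model::M        # a single AbstractNLPModel
  batch_size::Int
end

# One possible method for the exported stub `batch_obj`.
batch_obj(bnlp::MyBatchModel, xs) = [obj(bnlp.model, x) for x in xs]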
253 changes: 253 additions & 0 deletions src/nlp/batch/foreach.jl
@@ -0,0 +1,253 @@
export ForEachBatchNLPModel
"""
    ForEachBatchNLPModel(models)

Batch wrapper that applies each NLPModels API call to every wrapped model in turn.
The batch size is the number of wrapped models.
"""
struct ForEachBatchNLPModel{M} <: AbstractBatchNLPModel
  models::M
  counters::Counters
  batch_size::Int
end

function ForEachBatchNLPModel(models::M) where {M}
  isempty(models) && error("Cannot create ForEachBatchNLPModel from empty collection.")
  ForEachBatchNLPModel{M}(models, Counters(), length(models))
end

Base.length(vnlp::ForEachBatchNLPModel) = vnlp.batch_size
Base.getindex(vnlp::ForEachBatchNLPModel, i::Integer) = vnlp.models[i]
Base.iterate(vnlp::ForEachBatchNLPModel, state::Integer = 1) = iterate(vnlp.models, state)


# Apply `f(bnlp[i], xs[1][i], xs[2][i], ...)` for each wrapped model and collect the results.
function _batch_map(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}
  n = bnlp.batch_size
  @lencheck_tup n xs
  results = Vector{Any}(undef, n)
  for i = 1:n
    args_i = (x[i] for x in xs)
    results[i] = f(bnlp[i], args_i...)
  end
  return results
end

# In-place variant: apply `f` for each model and return the last argument collection
# (by convention, the output buffers).
function _batch_map!(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}
  n = bnlp.batch_size
  length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.")
  @lencheck_tup n xs
  for i = 1:n
    args_i = (x[i] for x in xs)
    f(bnlp[i], args_i...)
  end
  return xs[end]
end

# Like `_batch_map`, but forwards one `obj_weight` per model as a keyword argument.
function _batch_map_weight(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N}
  n = bnlp.batch_size
  @lencheck_tup n xs
  @lencheck n obj_weights
  results = Vector{Any}(undef, n)
  for i = 1:n
    args_i = (x[i] for x in xs)
    results[i] = f(bnlp[i], args_i...; obj_weight = obj_weights[i])
  end
  return results
end

# In-place variant of `_batch_map_weight`: returns the last argument collection.
function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N}
  n = bnlp.batch_size
  length(xs) == 0 && error("Cannot call _batch_map_weight! without providing arguments.")
  @lencheck_tup n xs
  @lencheck n obj_weights
  for i = 1:n
    args_i = (x[i] for x in xs)
    f(bnlp[i], args_i...; obj_weight = obj_weights[i])
  end
  return xs[end]
end

# Map `f` over the batch when `f` returns a 2-tuple, and split the results into two vectors.
function _batch_map_tuple(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}
  n = bnlp.batch_size
  @lencheck_tup n xs
  results = _batch_map(f, bnlp, xs...)

  first_result = first(results)
  T1, T2 = typeof(first_result[1]), typeof(first_result[2])
  vec1, vec2 = Vector{T1}(undef, n), Vector{T2}(undef, n)
  for i = 1:n
    vec1[i], vec2[i] = results[i]
  end
  return vec1, vec2
end

# In-place tuple variant: `f(model, args..., output)` returns `(first, output)`;
# collect the first components and return them alongside the output buffers.
function _batch_map_tuple!(f::F, bnlp::ForEachBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N}
  n = bnlp.batch_size
  @lencheck_tup n xs
  @lencheck n outputs
  firsts = Vector{Any}(undef, n)
  for i = 1:n
    args_i = (x[i] for x in xs)
    firsts[i], _ = f(bnlp[i], args_i..., outputs[i])
  end
  return firsts, outputs
end

for fun in fieldnames(Counters)
  @eval function NLPModels.increment!(bnlp::ForEachBatchNLPModel, ::Val{$(Meta.quot(fun))})
    # sub-model counters are already incremented since we call their methods
    bnlp.counters.$fun += 1
  end
end
for (batch_fun, fun) in (
  (:batch_jac_structure, :jac_structure),
  (:batch_jac_lin_structure, :jac_lin_structure),
  (:batch_jac_nln_structure, :jac_nln_structure),
  (:batch_hess_structure, :hess_structure),
  (:batch_varscale, :varscale),
  (:batch_lagscale, :lagscale),
  (:batch_conscale, :conscale),
)
  @eval $batch_fun(bnlp::ForEachBatchNLPModel) = _batch_map($fun, bnlp)
end
for (batch_fun, fun) in (
  (:batch_obj, :obj),
  (:batch_grad, :grad),
  (:batch_cons, :cons),
  (:batch_cons_lin, :cons_lin),
  (:batch_cons_nln, :cons_nln),
  (:batch_jac, :jac),
  (:batch_jac_lin, :jac_lin),
  (:batch_jac_nln, :jac_nln),
  (:batch_jac_lin_coord, :jac_lin_coord),
  (:batch_jac_coord, :jac_coord),
  (:batch_jac_nln_coord, :jac_nln_coord),
)
  @eval $batch_fun(bnlp::ForEachBatchNLPModel, x) = _batch_map($fun, bnlp, x)
end
for (batch_fun, fun) in (
  (:batch_jprod, :jprod),
  (:batch_jtprod, :jtprod),
  (:batch_jprod_nln, :jprod_nln),
  (:batch_jtprod_nln, :jtprod_nln),
  (:batch_jprod_lin, :jprod_lin),
  (:batch_jtprod_lin, :jtprod_lin),
)
  @eval $batch_fun(bnlp::ForEachBatchNLPModel, x, y) = _batch_map($fun, bnlp, x, y)
end
for (batch_fun, fun) in (
  (:batch_ghjvprod, :ghjvprod),
)
  @eval $batch_fun(bnlp::ForEachBatchNLPModel, x, y, z) = _batch_map($fun, bnlp, x, y, z)
end
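# For reference, the `@eval` loops above expand into ordinary method definitions such as
#   batch_obj(bnlp::ForEachBatchNLPModel, x) = _batch_map(obj, bnlp, x)
#   batch_jprod(bnlp::ForEachBatchNLPModel, x, y) = _batch_map(jprod, bnlp, x, y)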

batch_jac_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(jac_structure!, bnlp, rowss, colss)
batch_jac_lin_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(jac_lin_structure!, bnlp, rowss, colss)
batch_jac_nln_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(jac_nln_structure!, bnlp, rowss, colss)
batch_hess_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(hess_structure!, bnlp, rowss, colss)
batch_jac_lin_coord!(bnlp::ForEachBatchNLPModel, xs, valss) =
_batch_map!(jac_lin_coord!, bnlp, xs, valss)
batch_grad!(bnlp::ForEachBatchNLPModel, xs, gs) =
_batch_map!(grad!, bnlp, xs, gs)
batch_cons!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map!(cons!, bnlp, xs, cs)
batch_cons_lin!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map!(cons_lin!, bnlp, xs, cs)
batch_cons_nln!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map!(cons_nln!, bnlp, xs, cs)
batch_jac_coord!(bnlp::ForEachBatchNLPModel, xs, valss) =
_batch_map!(jac_coord!, bnlp, xs, valss)
batch_jac_nln_coord!(bnlp::ForEachBatchNLPModel, xs, valss) =
_batch_map!(jac_nln_coord!, bnlp, xs, valss)
batch_jprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) =
_batch_map!(jprod!, bnlp, xs, vs, Jvs)
batch_jtprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) =
_batch_map!(jtprod!, bnlp, xs, vs, Jtvs)
batch_jprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) =
_batch_map!(jprod_nln!, bnlp, xs, vs, Jvs)
batch_jtprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) =
_batch_map!(jtprod_nln!, bnlp, xs, vs, Jtvs)
batch_jprod_lin!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) =
_batch_map!(jprod_lin!, bnlp, xs, vs, Jvs)
batch_jtprod_lin!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) =
_batch_map!(jtprod_lin!, bnlp, xs, vs, Jtvs)
batch_ghjvprod!(bnlp::ForEachBatchNLPModel, xs, gs, vs, gHvs) =
_batch_map!(ghjvprod!, bnlp, xs, gs, vs, gHvs)

## jth
batch_jth_con(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_con(m, x, j), bnlp, xs)
batch_jth_congrad(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_congrad(m, x, j), bnlp, xs)
batch_jth_sparse_congrad(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_sparse_congrad(m, x, j), bnlp, xs)
batch_jth_hess_coord(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_hess_coord(m, x, j), bnlp, xs)
batch_jth_hess(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_hess(m, x, j), bnlp, xs)
batch_jth_hprod(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer) =
_batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs)

batch_jth_congrad!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) =
_batch_map!((m, x, out) -> jth_congrad!(m, x, j, out), bnlp, xs, outputs)
batch_jth_hess_coord!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) =
_batch_map!((m, x, out) -> jth_hess_coord!(m, x, j, out), bnlp, xs, outputs)
batch_jth_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer, outputs) =
_batch_map!((m, x, v, out) -> jth_hprod!(m, x, v, j, out), bnlp, xs, vs, outputs)

# hess family (obj_weight needs per-problem handling). FIXME: obj_weights is currently a required keyword in the batch API
batch_hprod(bnlp::ForEachBatchNLPModel, xs, vs; obj_weights) =
_batch_map_weight(hprod, bnlp, obj_weights, xs, vs)
batch_hprod(bnlp::ForEachBatchNLPModel, xs, ys, vs; obj_weights) =
_batch_map_weight(hprod, bnlp, obj_weights, xs, ys, vs)
batch_hess_coord(bnlp::ForEachBatchNLPModel, xs; obj_weights) =
_batch_map_weight(hess_coord, bnlp, obj_weights, xs)
batch_hess_coord(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) =
_batch_map_weight(hess_coord, bnlp, obj_weights, xs, ys)
batch_hess_op(bnlp::ForEachBatchNLPModel, xs; obj_weights) =
_batch_map_weight(hess_op, bnlp, obj_weights, xs)
batch_hess_op(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) =
_batch_map_weight(hess_op, bnlp, obj_weights, xs, ys)

batch_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, outputs; obj_weights) =
_batch_map_weight!(hprod!, bnlp, obj_weights, xs, vs, outputs)
batch_hprod!(bnlp::ForEachBatchNLPModel, xs, ys, vs, outputs; obj_weights) =
_batch_map_weight!(hprod!, bnlp, obj_weights, xs, ys, vs, outputs)
batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, outputs; obj_weights) =
_batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, outputs)
batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, ys, outputs; obj_weights) =
_batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, ys, outputs)
# `hess_op!` returns operators (the `Hvs` buffers are captured), so collect the results
# with the non-mutating mapper.
batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, Hvs; obj_weights) =
_batch_map_weight(hess_op!, bnlp, obj_weights, xs, Hvs)
batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, ys, Hvs; obj_weights) =
_batch_map_weight(hess_op!, bnlp, obj_weights, xs, ys, Hvs)

batch_hess(bnlp::ForEachBatchNLPModel, xs; obj_weights) =
_batch_map_weight(hess, bnlp, obj_weights, xs)
batch_hess(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) =
_batch_map_weight(hess, bnlp, obj_weights, xs, ys)

## operators
batch_jac_op(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_op, bnlp, xs)
batch_jac_lin_op(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_lin_op, bnlp, xs)
batch_jac_nln_op(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_nln_op, bnlp, xs)

batch_jac_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
_batch_map(jac_op!, bnlp, xs, Jvs, Jtvs)
batch_jac_lin_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
_batch_map(jac_lin_op!, bnlp, xs, Jvs, Jtvs)
batch_jac_nln_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
_batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs)

## tuple functions
batch_objgrad(bnlp::ForEachBatchNLPModel, xs) =
_batch_map_tuple(objgrad, bnlp, xs)
batch_objcons(bnlp::ForEachBatchNLPModel, xs) =
_batch_map_tuple(objcons, bnlp, xs)

batch_objgrad!(bnlp::ForEachBatchNLPModel, xs, gs) =
_batch_map_tuple!(objgrad!, bnlp, gs, xs)
batch_objcons!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map_tuple!(objcons!, bnlp, cs, xs)
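
For orientation, here is a minimal usage sketch of the for-each backend. `nlp1` and `nlp2` stand for any two existing `AbstractNLPModel` instances (for example from ADNLPModels.jl) and are not part of this diff; the calls follow the definitions above.

# Hypothetical usage sketch (nlp1, nlp2 are assumed AbstractNLPModel instances).
bnlp = ForEachBatchNLPModel([nlp1, nlp2])

xs = [nlp1.meta.x0, nlp2.meta.x0]   # one point per wrapped problem
fs = batch_obj(bnlp, xs)            # vector of objective values
gs = [similar(x) for x in xs]
batch_grad!(bnlp, xs, gs)           # gradients written in place; returns gs

# The Hessian family takes one objective weight per problem (required keyword).
vs = [ones(length(x)) for x in xs]
Hvs = [similar(x) for x in xs]
batch_hprod!(bnlp, xs, vs, Hvs; obj_weights = [1.0, 1.0])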