Dev #33 (Merged)

Changes from all commits
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "HybridVariationalInference"
uuid = "a108c475-a4e2-4021-9a84-cfa7df242f64"
authors = ["Thomas Wutzler <twutz@bgc-jena.mpg.de> and contributors"]
version = "0.2"
version = "0.2.0"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
@@ -31,6 +31,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
@@ -69,14 +70,15 @@ MLUtils = "0.4.5"
Missings = "1.2.0"
NaNMath = "1.1.3"
Optimisers = "0.4.6"
Optimization = "3.19.3, 4"
Optimization = "3.11, 4"
Random = "1.10.0"
SimpleChains = "0.4"
StableRNGs = "1.0.2"
StaticArrays = "1.9.13"
StatsBase = "0.34.4"
StatsFuns = "1.3.2"
Test = "1.10"
UnPack = "1.0.2"
Zygote = "0.7.10"
julia = "1.10"

7 changes: 4 additions & 3 deletions README.md
@@ -55,9 +55,10 @@ of the posterior. It returns a NamedTuple of
- the machine learning model parameters (usually weights), $\phi_g$
- means of the global parameters, $\phi_P = \mu_{\zeta_P}$ at transformed
unconstrained scale
- additional parameters, $\phi_{unc}$ of the posterior, $q(\zeta)$, such as
coefficients that describe the scaling of variance with magnitude
and coefficients that parameterize the choleski-factor or the correlation matrix.
- additional parameters, $\phi_q$, of the posterior, $q(\zeta)$, such as
  - coefficients that describe the scaling of variance with magnitude
  - coefficients that parameterize the Cholesky factor of the correlation matrix
  - means of the global parameters at unconstrained scale
- `θP`: predicted means of the global parameters, $\theta_P$
- `resopt`: the original result object of the optimizer (useful for debugging)
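
For orientation, a minimal sketch of accessing the returned values; the names `res`, `prob`, and `solver` are illustrative, see the tutorials for the full setup:

``` julia
res = solve(prob, solver)   # hypothetical invocation
res.θP                      # predicted means of the global parameters
res.resopt                  # raw optimizer result, useful for debugging
```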

2 changes: 2 additions & 0 deletions docs/src/tutorials/Project.toml
@@ -15,6 +15,8 @@ PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
55 changes: 25 additions & 30 deletions docs/src/tutorials/basic_cpu.md
@@ -17,6 +17,7 @@ using SimpleChains
using StatsFuns
using MLUtils
using DistributionFits
using UnPack
```

Next, specify many moving parts of the Hybrid variational inference (HVI)
@@ -33,9 +34,7 @@ $$
``` julia
function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
# extract parameters not depending on order, i.e. whether they are in θP or θM
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
CA.getdata(θc[par])::ET
end
@unpack r0, r1, K1, K2 = θc
r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
end
```
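
For readers unfamiliar with `UnPack`, the following self-contained sketch shows what `@unpack` does with a `ComponentVector` (the numbers are made up):

``` julia
using ComponentArrays: ComponentVector
using UnPack

θc = ComponentVector(r0 = 0.1, r1 = 2.0, K1 = 0.5, K2 = 3.0)
# @unpack reads each listed name via getproperty, so the extraction
# does not depend on the order of the components in θc
@unpack r1, K2 = θc
(r1, K2) == (θc.r1, θc.K2)   # true
```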
@@ -157,15 +156,16 @@ the problem below.

### Providing data in batches

HVI uses `MLUtils.DataLoader` to provide baches of the data during each
HVI uses `MLUtils.DataLoader` to provide batches of the data during each
iteration of the solver. In addition to the data, each yielded tuple
includes an index of the sites in the batch.

``` julia
n_site = size(y_o,2)
n_batch = 20
train_dataloader = MLUtils.DataLoader(
(xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
(CA.getdata(xM), CA.getdata(xP), y_o, y_unc, 1:n_site),
batchsize=n_batch, partial=false)
```
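
To see what one iterate of the loader looks like, it can be inspected directly (a sketch; the destructured names are illustrative):

``` julia
# each iterate is a tuple whose last element indexes the sites of the batch;
# arrays are batched along their last (site) dimension
(xM_b, xP_b, y_o_b, y_unc_b, i_sites) = first(train_dataloader)
size(y_o_b, 2) == n_batch    # 20 sites per batch
length(i_sites) == n_batch   # indices of the sites in this batch
```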

## The Machine-Learning model
@@ -211,7 +211,7 @@ However, for simplicity, a [`NormalScalingModelApplicator`](@ref)
is fitted to the transformed 5% and 95% quantiles of the original prior.

``` julia
priorsM = [priors_dict[k] for k in keys(θM)]
priorsM = Tuple(priors_dict[k] for k in keys(θM))
lowers, uppers = get_quantile_transformed(priorsM, transM)
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
```
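
The fitting idea can be illustrated with a plain distribution; the prior and the log transformation below are stand-ins for the tutorial's `priors_dict` entries and `transM`, not the actual values:

``` julia
using Distributions

prior_K1 = LogNormal(log(1.0), 0.5)   # hypothetical prior of one parameter
q05 = quantile(prior_K1, 0.05)
q95 = quantile(prior_K1, 0.95)
# the quantiles mapped to the unconstrained (here: log) scale give the
# lower and upper anchors that the scaling applicator is fitted to
lower, upper = log(q05), log(q95)
```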
@@ -231,8 +231,9 @@ invocation of the process based model (PBM), defined at the beginning.

``` julia
f_batch = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
ϕq0 = init_hybrid_ϕq(MeanHVIApproximation(), θP, θM, transP)

prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0,
prob = HybridProblem(θM, ϕq0, g_chain_scaled, ϕg0,
f_batch, priors_dict, py,
transM, transP, train_dataloader, n_covar, n_site, n_batch)
```
@@ -267,7 +268,7 @@ Then the solver is applied to the problem using [`solve`](@ref)
for a given number of iterations or epochs.
For this tutorial, we additionally specify that the function to transfer structures to
the GPU is the identity function, so that everything stays on the CPU, and this tutorial
hence does not require ad GPU or GPU livraries.
hence does not require a GPU or GPU libraries.

Among the return values are
- `probo`: A copy of the HybridProblem, with updated optimized parameters
@@ -276,7 +277,7 @@ will help analyzing the results.

## Using a population-level process-based model

So far, the process-based model ram for each single site.
So far, the process-based model ran for each single site.
For this simple model, some performance gains result from matrix computations
when running the model for all sites within one batch simultaneously.

@@ -289,29 +290,25 @@ one site. For the drivers and predictions, one column corresponds to one site.
``` julia
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
# extract several covariates from xP
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
S1 = (CA.getdata(xPc[:S1,:])::ST)
S2 = (CA.getdata(xPc[:S2,:])::ST)
S1 = view(xPc, Val(:S1), :)
S2 = view(xPc, Val(:S2), :)
#
# extract the parameters as row-repeated vectors
n_obs = size(S1, 1)
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
p1 = CA.getdata(θc[:, par]) ::VT
repeat(p1', n_obs) # matrix: same for each concentration row in S1
end
#
# θc[:,:r0] is parameter r0 for each site in batch
# dot-multiplication of full matrix times row-vector repeats for each observation row
# also introduces zero for missing observations, leading to zero gradient there
is_valid = isfinite.(S1) .&& isfinite.(S2)
r0 = is_valid .* CA.getdata(θc[:, Val(:r0)])'
r1 = is_valid .* CA.getdata(θc[:, Val(:r1)])'
K1 = is_valid .* CA.getdata(θc[:, Val(:K1)])'
K2 = is_valid .* CA.getdata(θc[:, Val(:K2)])'
# each variable is a matrix (n_obs x n_site)
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
end
```
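
The broadcasting pattern at the heart of this function can be checked with plain arrays (a minimal sketch with made-up numbers):

``` julia
# an (n_obs x n_site) validity mask times a transposed per-site parameter
# vector (1 x n_site) expands the parameter to every observation row and
# zeroes the entries of invalid observations
is_valid = [true true; true false; true true]   # 3 observations, 2 sites
r0_sites = [0.1, 0.2]                           # one r0 per site
r0 = is_valid .* r0_sites'                      # 3x2 matrix, zero where invalid
```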

Again, the function should not rely on the order of parameters but use symbolic indexing
to extract the parameter vectors. For type stability of this symbolic indexing,
it uses a workaround to get the type of a single row.
Similarly, it uses type hints to index into the drivers, `xPc`, to extract
sub-matrices by symbols. Alternatively, here it could rely on the structure and
ordering of the columns in `xPc`.
to extract the parameter vectors.

A corresponding [`PBMPopulationApplicator`](@ref) transforms calls with
partitioned global and site parameters to calls of this matrix version of the PBM.
@@ -323,11 +320,9 @@ probo_sites = HybridProblem(probo; f_batch)
```

For numerical efficiency, the number of sites within one batch is part of the
`PBMPopulationApplicator`. Hence, we have two different functions, one applied
to a batch of site, and another applied to all sites.

As a test of the new applicator, the results are refined by running a few more
epochs of the optimization.
`PBMPopulationApplicator`. The problem stores an applicator for `n_batch` sites;
an applicator for `n_site_pred` sites can be obtained with
`create_nsite_applicator(f_batch, n_site_pred)`, as sketched below.
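
A sketch of that recombination, reusing only calls named in this tutorial (`n_site` as defined above; the variable names are illustrative):

``` julia
# hypothetical: an applicator sized for all n_site sites at once,
# stored in an updated copy of the problem
f_allsites = create_nsite_applicator(f_batch, n_site)
probo_allsites = HybridProblem(probo; f_batch = f_allsites)
```

As a test of the new applicator, the results are refined by running a few more
epochs of the optimization.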

``` julia
(; probo) = solve(probo_sites, solver; rng,
@@ -344,7 +339,7 @@ in the following [Inspect results of fitted problem](@ref) tutorial.
In order to use the results from this tutorial in other tutorials,
the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.

Before the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
Before saving, the problem is updated so that it uses the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
of the PBM in module `DoubleMM` rather than
module `Main`, to allow for easier reloading with JLD2.

56 changes: 26 additions & 30 deletions docs/src/tutorials/basic_cpu.qmd
@@ -27,6 +27,7 @@ using SimpleChains
using StatsFuns
using MLUtils
using DistributionFits
using UnPack
```

Next, specify many moving parts of the Hybrid variational inference (HVI)
@@ -42,9 +43,7 @@ $$
```{julia}
function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
# extract parameters not depending on order, i.e. whether they are in θP or θM
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
CA.getdata(θc[par])::ET
end
@unpack r0, r1, K1, K2 = θc
r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
end
```
@@ -166,15 +165,16 @@ the problem below.

### Providing data in batches

HVI uses `MLUtils.DataLoader` to provide baches of the data during each
HVI uses `MLUtils.DataLoader` to provide batches of the data during each
iteration of the solver. In addition to the data, each yielded tuple
includes an index of the sites in the batch.

```{julia}
n_site = size(y_o,2)
n_batch = 20
train_dataloader = MLUtils.DataLoader(
(xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
(CA.getdata(xM), CA.getdata(xP), y_o, y_unc, 1:n_site),
batchsize=n_batch, partial=false)
```

## The Machine-Learning model
@@ -220,7 +220,7 @@ However, for simplicity, a [`NormalScalingModelApplicator`](@ref)
is fitted to the transformed 5% and 95% quantiles of the original prior.

```{julia}
priorsM = [priors_dict[k] for k in keys(θM)]
priorsM = Tuple(priors_dict[k] for k in keys(θM))
lowers, uppers = get_quantile_transformed(priorsM, transM)
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
```
@@ -241,8 +241,9 @@ invocation of the process based model (PBM), defined at the beginning.

```{julia}
f_batch = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
ϕq0 = init_hybrid_ϕq(MeanHVIApproximation(), θP, θM, transP)

prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0,
prob = HybridProblem(θM, ϕq0, g_chain_scaled, ϕg0,
f_batch, priors_dict, py,
transM, transP, train_dataloader, n_covar, n_site, n_batch)
```
@@ -302,7 +303,7 @@ Then the solver is applied to the problem using [`solve`](@ref)
for a given number of iterations or epochs.
For this tutorial, we additionally specify that the function to transfer structures to
the GPU is the identity function, so that everything stays on the CPU, and this tutorial
hence does not require ad GPU or GPU livraries.
hence does not require a GPU or GPU libraries.

Among the return values are
- `probo`: A copy of the HybridProblem, with updated optimized parameters
@@ -311,7 +312,7 @@ Among the return values are

## Using a population-level process-based model

So far, the process-based model ram for each single site.
So far, the process-based model ran for each single site.
For this simple model, some performance gains result from matrix computations
when running the model for all sites within one batch simultaneously.

@@ -323,31 +324,28 @@ one site. For the drivers and predictions, one column corresponds to one site.


```{julia}
using StaticArrays
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
# extract several covariates from xP
ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
S1 = (CA.getdata(xPc[:S1,:])::ST)
S2 = (CA.getdata(xPc[:S2,:])::ST)
S1 = view(xPc, Val(:S1), :)
S2 = view(xPc, Val(:S2), :)
#
# extract the parameters as row-repeated vectors
n_obs = size(S1, 1)
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
p1 = CA.getdata(θc[:, par]) ::VT
repeat(p1', n_obs) # matrix: same for each concentration row in S1
end
#
# θc[:,:r0] is parameter r0 for each site in batch
# dot-multiplication of full matrix times row-vector repeats for each observation row
# also introduces zero for missing observations, leading to zero gradient there
is_valid = isfinite.(S1) .&& isfinite.(S2)
r0 = is_valid .* CA.getdata(θc[:, Val(:r0)])'
r1 = is_valid .* CA.getdata(θc[:, Val(:r1)])'
K1 = is_valid .* CA.getdata(θc[:, Val(:K1)])'
K2 = is_valid .* CA.getdata(θc[:, Val(:K2)])'
# each variable is a matrix (n_obs x n_site)
r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
end
```

Again, the function should not rely on the order of parameters but use symbolic indexing
to extract the parameter vectors. For type stability of this symbolic indexing,
it uses a workaround to get the type of a single row.
Similarly, it uses type hints to index into the drivers, `xPc`, to extract
sub-matrices by symbols. Alternatively, here it could rely on the structure and
ordering of the columns in `xPc`.
to extract the parameter vectors.

A corresponding [`PBMPopulationApplicator`](@ref) transforms calls with
partitioned global and site parameters to calls of this matrix version of the PBM.
@@ -359,11 +357,9 @@ probo_sites = HybridProblem(probo; f_batch)
```

For numerical efficiency, the number of sites within one batch is part of the
`PBMPopulationApplicator`. Hence, we have two different functions, one applied
to a batch of site, and another applied to all sites.

As a test of the new applicator, the results are refined by running a few more
epochs of the optimization.
`PBMPopulationApplicator`. The problem stores an applicator for `n_batch` sites;
an applicator for `n_site_pred` sites can be obtained with
`create_nsite_applicator(f_batch, n_site_pred)`.

As a test of the new applicator, the results are refined by running a few more
epochs of the optimization.

```{julia}
(; probo) = solve(probo_sites, solver; rng,
@@ -379,7 +375,7 @@ in the following [Inspect results of fitted problem](@ref) tutorial.
In order to use the results from this tutorial in other tutorials,
the updated `probo` `HybridProblem` and the interpreters are saved to a JLD2 file.

Before the problem is updated to use the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
Before saving, the problem is updated so that it uses the redefinition [`DoubleMM.f_doubleMM_sites`](@ref)
of the PBM in module `DoubleMM` rather than
module `Main`, to allow for easier reloading with JLD2.
