Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 12 additions & 24 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
version = "0.1.1"

[[deps.BioCore]]
deps = ["Automa", "BufferedStreams", "YAML"]
git-tree-sha1 = "476edbf4ef94594fff430a84ca96f86cb2327a71"
uuid = "37cfa864-2cd6-5c12-ad9e-b6597d696c81"
version = "2.0.5"
[[deps.BioGenerics]]
deps = ["TranscodingStreams"]
git-tree-sha1 = "7bbc085aebc6faa615740b63756e4986c9e85a70"
uuid = "47718e42-2ac5-11e9-14af-e5595289c2ea"
version = "0.1.4"

[[deps.BitFlags]]
git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b"
Expand Down Expand Up @@ -2008,12 +2008,6 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"]
StatsFunsChainRulesCoreExt = "ChainRulesCore"
StatsFunsInverseFunctionsExt = "InverseFunctions"

[[deps.StringEncodings]]
deps = ["Libiconv_jll"]
git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb"
uuid = "69024149-9ee7-55f6-a4c4-859efe599b68"
version = "0.3.7"

[[deps.StringManipulation]]
deps = ["PrecompileTools"]
git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
Expand Down Expand Up @@ -2177,10 +2171,10 @@ uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.3"

[[deps.VariantCallFormat]]
deps = ["Automa", "BGZFStreams", "BioCore", "BufferedStreams"]
git-tree-sha1 = "f73ea34d3085cdbf6a18fa4c4b690e0f4a147730"
deps = ["Automa", "BGZFStreams", "BioGenerics", "BufferedStreams"]
git-tree-sha1 = "96fbe09c9e3b488666c883772fed8f6c1256c714"
uuid = "28eba6e3-a997-4ad9-87c6-d933b8bca6c1"
version = "0.5.5"
version = "0.5.6"

[[deps.VectorizationBase]]
deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static", "StaticArrayInterface"]
Expand Down Expand Up @@ -2219,9 +2213,9 @@ version = "1.1.34+0"

[[deps.XZ_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "31c421e5516a6248dfb22c194519e37effbf1f30"
git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632"
uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800"
version = "5.6.1+0"
version = "5.4.6+0"

[[deps.Xorg_libX11_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"]
Expand Down Expand Up @@ -2271,22 +2265,16 @@ git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77"
uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10"
version = "1.5.0+0"

[[deps.YAML]]
deps = ["Base64", "Dates", "Printf", "StringEncodings"]
git-tree-sha1 = "e6330e4b731a6af7959673621e91645eb1356884"
uuid = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
version = "0.4.9"

[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+0"

[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c"
git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b"
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
version = "1.5.5+0"
version = "1.5.6+0"

[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
Expand Down
62 changes: 46 additions & 16 deletions src/loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ end
Calculates the log density of β based on a spiek and slab prior
"""
function log_prior(β::Vector, σ2_β::Vector, p_causal::Vector)

P = length(β)
# prob_slab = 0.10
# L = prob_slab * 1_000
Expand All @@ -26,6 +25,26 @@ function log_prior(β::Vector, σ2_β::Vector, p_causal::Vector)
return sum(logprobs)
end

## TO COMPLETE ##
function log_prior_lse(β::Vector, σ2_β::Vector, p_causal::Vector)

P = length(β)
spike_σ2 = 1e-6
slab_dist = Normal.(0, sqrt.(σ2_β .+ spike_σ2))
spike_dist = Normal.(0, sqrt(spike_σ2))

# compute log probabilities for slab and spike components using vectorized operations
log_prob_slab = logpdf.(slab_dist, β) .+ log.(p_causal)
log_prob_spike = logpdf.(spike_dist, β) .+ log.(1 .- p_causal)

# applying the Log-Sum-Exp trick using vectorized operations
max_log_prob = max.(log_prob_slab, log_prob_spike)
logprobs = max_log_prob .+ log.(exp.(log_prob_slab .- max_log_prob) .+ exp.(log_prob_spike .- max_log_prob))

# sum of log probabilities
return sum(logprobs)
end

"""
rss(β, coef, SE, R)
Calculate the summary statistic RSS likelihood
Expand Down Expand Up @@ -114,7 +133,8 @@ joint_log_prob(
"""
joint_log_prob(β::Vector, coef::Vector, SE::Vector, R::Matrix, σ2_β::Vector, p_causal::Vector, to) = rss(β, coef, SE, R, to) + log_prior(β, σ2_β, p_causal)

joint_log_prob(β::Vector, coef::Vector, Σ::AbstractPDMat, SRSinv::Matrix, σ2_β::Vector, p_causal::Vector, to) = rss(β, coef, Σ, SRSinv, to) + log_prior(β, σ2_β, p_causal)
#joint_log_prob(β::Vector, coef::Vector, Σ::AbstractPDMat, SRSinv::Matrix, σ2_β::Vector, p_causal::Vector, to) = rss(β, coef, Σ, SRSinv, to) + log_prior(β, σ2_β, p_causal)
joint_log_prob(β::Vector, coef::Vector, Σ::AbstractPDMat, SRSinv::Matrix, σ2_β::Vector, p_causal::Vector, to) = rss(β, coef, Σ, SRSinv, to) + log_prior_lse(β, σ2_β, p_causal)

"""
elbo(z, q_μ, log_q_var, coef, SE, R, σ2_β, p_causal)
Expand All @@ -133,21 +153,31 @@ elbo(
)
```
"""
function elbo(z::Vector, q_μ::Vector, log_q_var::Vector, coef::Vector, SE::Vector, R::AbstractArray, σ2_β::Vector, p_causal::Vector, to)
q_var = @timeit to "q_var" exp.(log_q_var)
q = @timeit to "q" MvNormal(q_μ, Diagonal(q_var))
q_sd = @timeit to "q_sd" sqrt.(q_var)
ϕ = @timeit to "ϕ" q_μ .+ q_sd .* z
# γ = compute_γ(q_μ, q_var)
# jl = joint_log_prob(γ .* ϕ, coef, SE, R)
jl = @timeit to "joint_log_prob" joint_log_prob(ϕ, coef, SE, R, σ2_β, p_causal, to)
q = @timeit to "logpd" logpdf(q, ϕ)
# jac = prod(z)
return (jl - q)
end
#function elbo(z::Vector, q_μ::Vector, log_q_var::Vector, coef::Vector, SE::Vector, R::AbstractArray, σ2_β::Vector, p_causal::Vector, to)
# q_var = @timeit to "q_var" exp.(log_q_var)
# q = @timeit to "q" MvNormal(q_μ, Diagonal(q_var))
# q_sd = @timeit to "q_sd" sqrt.(q_var)
# ϕ = @timeit to "ϕ" q_μ .+ q_sd .* z
# # γ = compute_γ(q_μ, q_var)
# # jl = joint_log_prob(γ .* ϕ, coef, SE, R)
# jl = @timeit to "joint_log_prob" joint_log_prob(ϕ, coef, SE, R, σ2_β, p_causal, to)
# q = @timeit to "logpd" logpdf(q, ϕ)
# # jac = prod(z)
# return (jl - q)
#end

#function elbo(z::Vector, q_μ::Vector, log_q_var::Vector, coef::Vector, Σ::AbstractPDMat, SRSinv::Matrix, σ2_β::Vector, p_causal::Vector, to)

function elbo(z::Vector, q_μ::Vector, log_q_var::Vector, coef::Vector, Σ::AbstractPDMat, SRSinv::Matrix, σ2_β::Vector, p_causal::Vector, to)
q_var = @timeit to "q_var" exp.(log_q_var)
"""
elbo(z, q_μ, q_var, coef, Σ, SRSinv, σ2_β, p_causal, to)

Calculate ELBO sampled from MC sampling
"""
function elbo(z::Vector, q_μ::Vector, q_var::Vector, coef::Vector, Σ::AbstractPDMat, SRSinv::Matrix, σ2_β::Vector, p_causal::Vector, to)
#q_var = @timeit to "q_var" exp.(log_q_var)
#if (any(x->x<0, q_var))
# println("Negative q_var: ", q_var[findall(x->x<0, q_var)]) # Debug output
#end
q = @timeit to "q" MvNormal(q_μ, Diagonal(q_var))
q_sd = @timeit to "q_sd" sqrt.(q_var)
ϕ = @timeit to "ϕ" q_μ .+ q_sd .* z
Expand Down
Loading