Skip to content
2 changes: 2 additions & 0 deletions mwe/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
54 changes: 54 additions & 0 deletions mwe/mwe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using BridgeStan

stan_code = """
functions {
real partial_sum(array[] real slice, int start, int end,
tuple(real, int) params) {
real mu = params.1;
int K = params.2;
real lp = 0;
for (i in 1:size(slice)) {
lp += normal_lpdf(slice[i] | mu, K);
}
return lp;
}
}
data {
int N;
int K;
array[N] real y;
}
parameters {
real mu;
}
model {
mu ~ normal(0, 10);
target += reduce_sum(partial_sum, y, 1, (mu, K));
}
"""

stan_math = ENV["MWE_RUN_DIR"]
label = ENV["MWE_LABEL"]

workdir = mktempdir()
stan_file = joinpath(workdir, "mwe.stan")
write(stan_file, stan_code)

lib = compile_model(stan_file; make_args=["MATH=$stan_math/", "STAN_THREADS=true"])

data = """{"N": 5, "K": 2, "y": [1.0, 2.0, 3.0, 4.0, 5.0]}"""
sm = StanModel(lib, data)

params = [3.0] # mu (unconstrained)
lp = log_density(sm, params)
lp_grad, grad = log_density_gradient(sm, params)

println("[$label] log_density = $lp")
println("[$label] gradient = $grad")

@assert isfinite(lp) "log_density should be finite"
@assert all(isfinite, grad) "gradient should be finite"
@assert lp ≈ lp_grad "log_density values should match"
@assert length(grad) == 1 "gradient should have 1 element"

println("[$label] All checks passed!")
26 changes: 26 additions & 0 deletions stan/math/rev/core/accumulate_adjoints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core/var.hpp>

Expand Down Expand Up @@ -35,6 +36,9 @@ inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args);

inline double* accumulate_adjoints(double* dest);

template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
inline double* accumulate_adjoints(double* dest, Tuple&& x, Pargs&&... args);

/**
* Accumulate adjoints from x into storage pointed to by dest,
* increment the adjoint storage pointer,
Expand Down Expand Up @@ -147,6 +151,28 @@ inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args) {
*/
inline double* accumulate_adjoints(double* dest) { return dest; }

/**
* Accumulate adjoints from a tuple into storage pointed to by dest
* by unpacking the tuple and recursively processing each element.
*
* @tparam Tuple A std::tuple type
* @tparam Pargs Types of remaining arguments
* @param dest Pointer to where adjoints are to be accumulated
* @param x A tuple potentially containing vars
* @param args Further args to accumulate over
* @return Final position of adjoint storage pointer
*/
template <typename Tuple, require_tuple_t<Tuple>*, typename... Pargs>
inline double* accumulate_adjoints(double* dest, Tuple&& x, Pargs&&... args) {
dest = stan::math::apply(
[dest](auto&&... inner_args) {
return accumulate_adjoints(
dest, std::forward<decltype(inner_args)>(inner_args)...);
},
std::forward<Tuple>(x));
return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
}

} // namespace math
} // namespace stan

Expand Down
19 changes: 19 additions & 0 deletions stan/math/rev/core/deep_copy_vars.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_REV_CORE_DEEP_COPY_VARS_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core/var.hpp>

Expand Down Expand Up @@ -81,6 +82,24 @@ inline auto deep_copy_vars(EigT&& arg) {
.eval();
}

/**
* Deep copy vars in a tuple, reallocating new varis for var elements
* and forwarding non-var elements unchanged.
*
* @tparam Tuple A std::tuple type
* @param arg A tuple potentially containing vars
* @return A new tuple with deep-copied vars
*/
template <typename Tuple, require_tuple_t<Tuple>* = nullptr>
inline auto deep_copy_vars(Tuple&& arg) {
return stan::math::apply(
[](auto&&... args) {
return std::make_tuple(
deep_copy_vars(std::forward<decltype(args)>(args))...);
},
std::forward<Tuple>(arg));
}

} // namespace math
} // namespace stan

Expand Down
26 changes: 26 additions & 0 deletions stan/math/rev/core/save_varis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core/var.hpp>

Expand Down Expand Up @@ -35,6 +36,9 @@ inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args);

inline vari** save_varis(vari** dest);

template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
inline vari** save_varis(vari** dest, Tuple&& x, Pargs&&... args);

/**
* Save the vari pointer in x into the memory pointed to by dest,
* increment the dest storage pointer,
Expand Down Expand Up @@ -143,6 +147,28 @@ inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args) {
*/
inline vari** save_varis(vari** dest) { return dest; }

/**
* Save the vari pointers in a tuple into the memory pointed to by dest
* by unpacking the tuple and recursively processing each element.
*
* @tparam Tuple A std::tuple type
* @tparam Pargs Types of remaining arguments
* @param[in, out] dest Pointer to where vari pointers are saved
* @param[in] x A tuple potentially containing vars
* @param[in] args Additional arguments to have their varis saved
* @return Final position of dest pointer
*/
template <typename Tuple, require_tuple_t<Tuple>*, typename... Pargs>
inline vari** save_varis(vari** dest, Tuple&& x, Pargs&&... args) {
dest = stan::math::apply(
[dest](auto&&... inner_args) {
return save_varis(dest,
std::forward<decltype(inner_args)>(inner_args)...);
},
std::forward<Tuple>(x));
return save_varis(dest, std::forward<Pargs>(args)...);
}

} // namespace math
} // namespace stan

Expand Down
38 changes: 38 additions & 0 deletions test/unit/math/rev/core/accumulate_adjoints_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,44 @@ TEST_F(AgradRev, Rev_accumulate_adjoints_std_vector_eigen_matrix_var_arg) {
stan::math::recover_memory();
}

TEST_F(AgradRev, Rev_accumulate_adjoints_tuple_var_int_arg) {
using stan::math::var;
using stan::math::vari;
var a(5.0);
a.vi_->adj_ = 3.0;
int b = 7;
auto arg = std::make_tuple(a, b);

Eigen::VectorXd storage = Eigen::VectorXd::Zero(1000);
double* ptr = stan::math::accumulate_adjoints(storage.data(), arg);

EXPECT_FLOAT_EQ(storage(0), 3.0);
for (int i = 1; i < storage.size(); ++i)
EXPECT_FLOAT_EQ(storage(i), 0.0);
EXPECT_EQ(ptr, storage.data() + 1);
stan::math::recover_memory();
}

TEST_F(AgradRev, Rev_accumulate_adjoints_tuple_var_var_arg) {
using stan::math::var;
using stan::math::vari;
var a(5.0);
a.vi_->adj_ = 3.0;
var b(7.0);
b.vi_->adj_ = 4.0;
auto arg = std::make_tuple(a, b);

Eigen::VectorXd storage = Eigen::VectorXd::Zero(1000);
double* ptr = stan::math::accumulate_adjoints(storage.data(), arg);

EXPECT_FLOAT_EQ(storage(0), 3.0);
EXPECT_FLOAT_EQ(storage(1), 4.0);
for (int i = 2; i < storage.size(); ++i)
EXPECT_FLOAT_EQ(storage(i), 0.0);
EXPECT_EQ(ptr, storage.data() + 2);
stan::math::recover_memory();
}

TEST_F(AgradRev, Rev_accumulate_adjoints_sum) {
using stan::math::var;
using stan::math::vari;
Expand Down
37 changes: 37 additions & 0 deletions test/unit/math/rev/core/deep_copy_vars_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,43 @@ TEST_F(AgradRev, Rev_deep_copy_vars_std_vector_eigen_row_vector_var_arg) {
}
}

TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_int_arg) {
var a(3.0);
int b = 5;
auto arg = std::make_tuple(a, b);

auto out = stan::math::deep_copy_vars(arg);

EXPECT_EQ(std::get<0>(out).val(), a.val());
EXPECT_NE(std::get<0>(out).vi_, a.vi_);
EXPECT_EQ(std::get<1>(out), b);
}

TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_double_arg) {
var a(3.0);
double b = 5.0;
auto arg = std::make_tuple(a, b);

auto out = stan::math::deep_copy_vars(arg);

EXPECT_EQ(std::get<0>(out).val(), a.val());
EXPECT_NE(std::get<0>(out).vi_, a.vi_);
EXPECT_EQ(std::get<1>(out), b);
}

TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_var_arg) {
var a(3.0);
var b(7.0);
auto arg = std::make_tuple(a, b);

auto out = stan::math::deep_copy_vars(arg);

EXPECT_EQ(std::get<0>(out).val(), a.val());
EXPECT_NE(std::get<0>(out).vi_, a.vi_);
EXPECT_EQ(std::get<1>(out).val(), b.val());
EXPECT_NE(std::get<1>(out).vi_, b.vi_);
}

TEST_F(AgradRev, Rev_deep_copy_vars_std_vector_eigen_matrix_var_arg) {
Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> arg_(5, 3);
std::vector<Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic>> arg(2, arg_);
Expand Down
33 changes: 33 additions & 0 deletions test/unit/math/rev/core/save_varis_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,39 @@ TEST_F(AgradRev, Rev_save_varis_std_vector_eigen_matrix_var_arg) {
EXPECT_EQ(ptr, storage.data() + num_vars);
}

TEST_F(AgradRev, Rev_save_varis_tuple_var_int_arg) {
var a(5.0);
int b = 3;
auto arg = std::make_tuple(a, b);

std::vector<vari*> storage(1000, nullptr);
vari** ptr = stan::math::save_varis(storage.data(), arg);

size_t num_vars = stan::math::count_vars(arg);
EXPECT_EQ(num_vars, 1);
EXPECT_EQ(storage[0], a.vi_);
for (int i = num_vars; i < storage.size(); ++i)
EXPECT_EQ(storage[i], nullptr);
EXPECT_EQ(ptr, storage.data() + num_vars);
}

TEST_F(AgradRev, Rev_save_varis_tuple_var_var_arg) {
var a(5.0);
var b(7.0);
auto arg = std::make_tuple(a, b);

std::vector<vari*> storage(1000, nullptr);
vari** ptr = stan::math::save_varis(storage.data(), arg);

size_t num_vars = stan::math::count_vars(arg);
EXPECT_EQ(num_vars, 2);
EXPECT_EQ(storage[0], a.vi_);
EXPECT_EQ(storage[1], b.vi_);
for (int i = num_vars; i < storage.size(); ++i)
EXPECT_EQ(storage[i], nullptr);
EXPECT_EQ(ptr, storage.data() + num_vars);
}

TEST_F(AgradRev, Rev_save_varis_sum) {
int arg1 = 1;
double arg2 = 1.0;
Expand Down