diff --git a/mwe/Project.toml b/mwe/Project.toml new file mode 100644 index 00000000000..bdf49f9a5ef --- /dev/null +++ b/mwe/Project.toml @@ -0,0 +1,2 @@ +[deps] +BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a" diff --git a/mwe/mwe.jl b/mwe/mwe.jl new file mode 100644 index 00000000000..b2ab7ea189d --- /dev/null +++ b/mwe/mwe.jl @@ -0,0 +1,54 @@ +using BridgeStan + +stan_code = """ +functions { + real partial_sum(array[] real slice, int start, int end, + tuple(real, int) params) { + real mu = params.1; + int K = params.2; + real lp = 0; + for (i in 1:size(slice)) { + lp += normal_lpdf(slice[i] | mu, K); + } + return lp; + } +} +data { + int N; + int K; + array[N] real y; +} +parameters { + real mu; +} +model { + mu ~ normal(0, 10); + target += reduce_sum(partial_sum, y, 1, (mu, K)); +} +""" + +stan_math = ENV["MWE_RUN_DIR"] +label = ENV["MWE_LABEL"] + +workdir = mktempdir() +stan_file = joinpath(workdir, "mwe.stan") +write(stan_file, stan_code) + +lib = compile_model(stan_file; make_args=["MATH=$stan_math/", "STAN_THREADS=true"]) + +data = """{"N": 5, "K": 2, "y": [1.0, 2.0, 3.0, 4.0, 5.0]}""" +sm = StanModel(lib, data) + +params = [3.0] # mu (unconstrained) +lp = log_density(sm, params) +lp_grad, grad = log_density_gradient(sm, params) + +println("[$label] log_density = $lp") +println("[$label] gradient = $grad") + +@assert isfinite(lp) "log_density should be finite" +@assert all(isfinite, grad) "gradient should be finite" +@assert lp ≈ lp_grad "log_density values should match" +@assert length(grad) == 1 "gradient should have 1 element" + +println("[$label] All checks passed!") diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index e5b27354ebd..f046bc66941 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP #include +#include #include #include @@ -35,6 +36,9 @@ inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args); inline double* accumulate_adjoints(double* dest); +template * = nullptr, typename... Pargs> +inline double* accumulate_adjoints(double* dest, Tuple&& x, Pargs&&... args); + /** * Accumulate adjoints from x into storage pointed to by dest, * increment the adjoint storage pointer, @@ -147,6 +151,28 @@ inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args) { */ inline double* accumulate_adjoints(double* dest) { return dest; } +/** + * Accumulate adjoints from a tuple into storage pointed to by dest + * by unpacking the tuple and recursively processing each element. + * + * @tparam Tuple A std::tuple type + * @tparam Pargs Types of remaining arguments + * @param dest Pointer to where adjoints are to be accumulated + * @param x A tuple potentially containing vars + * @param args Further args to accumulate over + * @return Final position of adjoint storage pointer + */ +template *, typename... Pargs> +inline double* accumulate_adjoints(double* dest, Tuple&& x, Pargs&&... args) { + dest = stan::math::apply( + [dest](auto&&... inner_args) { + return accumulate_adjoints( + dest, std::forward(inner_args)...); + }, + std::forward(x)); + return accumulate_adjoints(dest, std::forward(args)...); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 06561d1a9e0..50e8c11f642 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_CORE_DEEP_COPY_VARS_HPP #include +#include #include #include @@ -81,6 +82,24 @@ inline auto deep_copy_vars(EigT&& arg) { .eval(); } +/** + * Deep copy vars in a tuple, reallocating new varis for var elements + * and forwarding non-var elements unchanged. + * + * @tparam Tuple A std::tuple type + * @param arg A tuple potentially containing vars + * @return A new tuple with deep-copied vars + */ +template * = nullptr> +inline auto deep_copy_vars(Tuple&& arg) { + return stan::math::apply( + [](auto&&... args) { + return std::make_tuple( + deep_copy_vars(std::forward(args))...); + }, + std::forward(arg)); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index c53a5390539..3293fe4d73a 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -35,6 +36,9 @@ inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args); inline vari** save_varis(vari** dest); +template * = nullptr, typename... Pargs> +inline vari** save_varis(vari** dest, Tuple&& x, Pargs&&... args); + /** * Save the vari pointer in x into the memory pointed to by dest, * increment the dest storage pointer, @@ -143,6 +147,28 @@ inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args) { */ inline vari** save_varis(vari** dest) { return dest; } +/** + * Save the vari pointers in a tuple into the memory pointed to by dest + * by unpacking the tuple and recursively processing each element. + * + * @tparam Tuple A std::tuple type + * @tparam Pargs Types of remaining arguments + * @param[in, out] dest Pointer to where vari pointers are saved + * @param[in] x A tuple potentially containing vars + * @param[in] args Additional arguments to have their varis saved + * @return Final position of dest pointer + */ +template *, typename... Pargs> +inline vari** save_varis(vari** dest, Tuple&& x, Pargs&&... args) { + dest = stan::math::apply( + [dest](auto&&... inner_args) { + return save_varis(dest, + std::forward(inner_args)...); + }, + std::forward(x)); + return save_varis(dest, std::forward(args)...); +} + } // namespace math } // namespace stan diff --git a/test/unit/math/rev/core/accumulate_adjoints_test.cpp b/test/unit/math/rev/core/accumulate_adjoints_test.cpp index e8017993f73..52faea0d7d7 100644 --- a/test/unit/math/rev/core/accumulate_adjoints_test.cpp +++ b/test/unit/math/rev/core/accumulate_adjoints_test.cpp @@ -381,6 +381,44 @@ TEST_F(AgradRev, Rev_accumulate_adjoints_std_vector_eigen_matrix_var_arg) { stan::math::recover_memory(); } +TEST_F(AgradRev, Rev_accumulate_adjoints_tuple_var_int_arg) { + using stan::math::var; + using stan::math::vari; + var a(5.0); + a.vi_->adj_ = 3.0; + int b = 7; + auto arg = std::make_tuple(a, b); + + Eigen::VectorXd storage = Eigen::VectorXd::Zero(1000); + double* ptr = stan::math::accumulate_adjoints(storage.data(), arg); + + EXPECT_FLOAT_EQ(storage(0), 3.0); + for (int i = 1; i < storage.size(); ++i) + EXPECT_FLOAT_EQ(storage(i), 0.0); + EXPECT_EQ(ptr, storage.data() + 1); + stan::math::recover_memory(); +} + +TEST_F(AgradRev, Rev_accumulate_adjoints_tuple_var_var_arg) { + using stan::math::var; + using stan::math::vari; + var a(5.0); + a.vi_->adj_ = 3.0; + var b(7.0); + b.vi_->adj_ = 4.0; + auto arg = std::make_tuple(a, b); + + Eigen::VectorXd storage = Eigen::VectorXd::Zero(1000); + double* ptr = stan::math::accumulate_adjoints(storage.data(), arg); + + EXPECT_FLOAT_EQ(storage(0), 3.0); + EXPECT_FLOAT_EQ(storage(1), 4.0); + for (int i = 2; i < storage.size(); ++i) + EXPECT_FLOAT_EQ(storage(i), 0.0); + EXPECT_EQ(ptr, storage.data() + 2); + stan::math::recover_memory(); +} + TEST_F(AgradRev, Rev_accumulate_adjoints_sum) { using stan::math::var; using stan::math::vari; diff --git a/test/unit/math/rev/core/deep_copy_vars_test.cpp b/test/unit/math/rev/core/deep_copy_vars_test.cpp index 99f78c9fb9c..6cba3a8134b 100644 --- a/test/unit/math/rev/core/deep_copy_vars_test.cpp +++ b/test/unit/math/rev/core/deep_copy_vars_test.cpp @@ -283,6 +283,43 @@ TEST_F(AgradRev, Rev_deep_copy_vars_std_vector_eigen_row_vector_var_arg) { } } +TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_int_arg) { + var a(3.0); + int b = 5; + auto arg = std::make_tuple(a, b); + + auto out = stan::math::deep_copy_vars(arg); + + EXPECT_EQ(std::get<0>(out).val(), a.val()); + EXPECT_NE(std::get<0>(out).vi_, a.vi_); + EXPECT_EQ(std::get<1>(out), b); +} + +TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_double_arg) { + var a(3.0); + double b = 5.0; + auto arg = std::make_tuple(a, b); + + auto out = stan::math::deep_copy_vars(arg); + + EXPECT_EQ(std::get<0>(out).val(), a.val()); + EXPECT_NE(std::get<0>(out).vi_, a.vi_); + EXPECT_EQ(std::get<1>(out), b); +} + +TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_var_arg) { + var a(3.0); + var b(7.0); + auto arg = std::make_tuple(a, b); + + auto out = stan::math::deep_copy_vars(arg); + + EXPECT_EQ(std::get<0>(out).val(), a.val()); + EXPECT_NE(std::get<0>(out).vi_, a.vi_); + EXPECT_EQ(std::get<1>(out).val(), b.val()); + EXPECT_NE(std::get<1>(out).vi_, b.vi_); +} + TEST_F(AgradRev, Rev_deep_copy_vars_std_vector_eigen_matrix_var_arg) { Eigen::Matrix arg_(5, 3); std::vector> arg(2, arg_); diff --git a/test/unit/math/rev/core/save_varis_test.cpp b/test/unit/math/rev/core/save_varis_test.cpp index 5f8a69ce0d9..110cc7d7d3d 100644 --- a/test/unit/math/rev/core/save_varis_test.cpp +++ b/test/unit/math/rev/core/save_varis_test.cpp @@ -332,6 +332,39 @@ TEST_F(AgradRev, Rev_save_varis_std_vector_eigen_matrix_var_arg) { EXPECT_EQ(ptr, storage.data() + num_vars); } +TEST_F(AgradRev, Rev_save_varis_tuple_var_int_arg) { + var a(5.0); + int b = 3; + auto arg = std::make_tuple(a, b); + + std::vector storage(1000, nullptr); + vari** ptr = stan::math::save_varis(storage.data(), arg); + + size_t num_vars = stan::math::count_vars(arg); + EXPECT_EQ(num_vars, 1); + EXPECT_EQ(storage[0], a.vi_); + for (int i = num_vars; i < storage.size(); ++i) + EXPECT_EQ(storage[i], nullptr); + EXPECT_EQ(ptr, storage.data() + num_vars); +} + +TEST_F(AgradRev, Rev_save_varis_tuple_var_var_arg) { + var a(5.0); + var b(7.0); + auto arg = std::make_tuple(a, b); + + std::vector storage(1000, nullptr); + vari** ptr = stan::math::save_varis(storage.data(), arg); + + size_t num_vars = stan::math::count_vars(arg); + EXPECT_EQ(num_vars, 2); + EXPECT_EQ(storage[0], a.vi_); + EXPECT_EQ(storage[1], b.vi_); + for (int i = num_vars; i < storage.size(); ++i) + EXPECT_EQ(storage[i], nullptr); + EXPECT_EQ(ptr, storage.data() + num_vars); +} + TEST_F(AgradRev, Rev_save_varis_sum) { int arg1 = 1; double arg2 = 1.0;