Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 110 additions & 25 deletions stan/math/prim/prob/gamma_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,92 @@
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/tgamma.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>

namespace stan {
namespace math {
namespace internal {
template <typename T>
struct Q_eval {
T log_Q{0.0};
T dlogQ_dalpha{0.0};
bool ok{false};
};

/**
* Computes log q and d(log q) / d(alpha) using continued fraction.
*/
template <typename T, typename T_shape, bool any_fvar, bool partials_fvar>
static inline Q_eval<T> eval_q_cf(const T& alpha, const T& beta_y) {
Q_eval<T> out;
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
auto log_q_result
= log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
out.log_Q = log_q_result.log_q;
out.dlogQ_dalpha = log_q_result.dlog_q_da;
} else {
out.log_Q = internal::log_q_gamma_cf(alpha, beta_y);
if constexpr (is_autodiff_v<T_shape>) {
if constexpr (!partials_fvar) {
out.dlogQ_dalpha
= grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha))
/ exp(out.log_Q);
} else {
T alpha_unit = alpha;
alpha_unit.d_ = 1;
T beta_y_unit = beta_y;
beta_y_unit.d_ = 0;
T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
out.dlogQ_dalpha = log_Q_fvar.d_;
}
}
}

out.ok = std::isfinite(value_of_rec(out.log_Q));
return out;
}

/**
* Computes log q and d(log q) / d(alpha) using log1m.
*/
template <typename T, typename T_shape, bool partials_fvar>
static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) {
Q_eval<T> out;
out.log_Q = log1m(gamma_p(alpha, beta_y));

if (!std::isfinite(value_of_rec(out.log_Q))) {
out.ok = false;
return out;
}

if constexpr (is_autodiff_v<T_shape>) {
if constexpr (partials_fvar) {
T alpha_unit = alpha;
alpha_unit.d_ = 1;
T beta_unit = beta_y;
beta_unit.d_ = 0;
T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
out.dlogQ_dalpha = log_Q_fvar.d_;
} else {
out.dlogQ_dalpha
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q);
}
}

out.ok = true;
return out;
}
} // namespace internal

template <typename T_y, typename T_shape, typename T_inv_scale>
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
using std::exp;
using std::log;
using std::pow;
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
using T_y_ref = ref_type_t<T_y>;
using T_alpha_ref = ref_type_t<T_shape>;
using T_beta_ref = ref_type_t<T_inv_scale>;
Expand All @@ -53,14 +125,10 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
size_t N = max_size(y, alpha, beta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < stan::math::size(y); i++) {
if (y_vec.val(i) == 0) {
// LCCDF(0) = log(P(Y > 0)) = log(1) = 0
return ops_partials.build(0.0);
}
}
constexpr bool any_fvar = is_fvar<scalar_type_t<T_y>>::value
|| is_fvar<scalar_type_t<T_shape>>::value
|| is_fvar<scalar_type_t<T_inv_scale>>::value;
constexpr bool partials_fvar = is_fvar<T_partials_return>::value;

for (size_t n = 0; n < N; n++) {
// Explicit results for extreme values
Expand All @@ -79,15 +147,37 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl);
const T_partials_return log_Qn = log(Qn);

P += log_Qn;
const bool use_continued_fraction = beta_y > alpha_dbl + 1.0;
internal::Q_eval<T_partials_return> result;
if (use_continued_fraction) {
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
partials_fvar>(alpha_dbl, beta_y);
} else {
result
= internal::eval_q_log1m<T_partials_return, T_shape, partials_fvar>(
alpha_dbl, beta_y);

if (!result.ok && beta_y > 0.0) {
// Fallback to continued fraction if log1m fails
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
partials_fvar>(alpha_dbl, beta_y);
}
}
if (!result.ok) {
return ops_partials.build(negative_infinity());
}

P += result.log_Q;

if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
const T_partials_return log_y = log(y_dbl);
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);

const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
- lgamma(alpha_dbl) + alpha_minus_one
- beta_y;

if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
const T_partials_return log_y_dbl = log(y_dbl);
const T_partials_return log_beta_dbl = log(beta_dbl);
const T_partials_return log_pdf
= alpha_dbl * log_beta_dbl - lgamma(alpha_dbl)
+ (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl;
const T_partials_return common_term = exp(log_pdf - log_Qn);
const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q

if constexpr (is_autodiff_v<T_y>) {
// d/dy log(1-F(y)) = -f(y)/(1-F(y))
Expand All @@ -100,12 +190,7 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
}

if constexpr (is_autodiff_v<T_shape>) {
const T_partials_return digamma_val = digamma(alpha_dbl);
const T_partials_return gamma_val = tgamma(alpha_dbl);
// d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
partials<1>(ops_partials)[n]
+= grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
/ Qn;
partials<1>(ops_partials)[n] += result.dlogQ_dalpha;
}
}
return ops_partials.build(P);
Expand Down
Loading