diff --git a/stan/math/prim/prob/gamma_lccdf.hpp b/stan/math/prim/prob/gamma_lccdf.hpp index a670cefcecf..c03204dbb6f 100644 --- a/stan/math/prim/prob/gamma_lccdf.hpp +++ b/stan/math/prim/prob/gamma_lccdf.hpp @@ -14,20 +14,92 @@ #include #include #include -#include +#include +#include #include #include namespace stan { namespace math { +namespace internal { +template +struct Q_eval { + T log_Q{0.0}; + T dlogQ_dalpha{0.0}; + bool ok{false}; +}; + +/** + * Computes log q and d(log q) / d(alpha) using continued fraction. + */ +template +static inline Q_eval eval_q_cf(const T& alpha, const T& beta_y) { + Q_eval out; + if constexpr (!any_fvar && is_autodiff_v) { + auto log_q_result + = log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y)); + out.log_Q = log_q_result.log_q; + out.dlogQ_dalpha = log_q_result.dlog_q_da; + } else { + out.log_Q = internal::log_q_gamma_cf(alpha, beta_y); + if constexpr (is_autodiff_v) { + if constexpr (!partials_fvar) { + out.dlogQ_dalpha + = grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha)) + / exp(out.log_Q); + } else { + T alpha_unit = alpha; + alpha_unit.d_ = 1; + T beta_y_unit = beta_y; + beta_y_unit.d_ = 0; + T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit); + out.dlogQ_dalpha = log_Q_fvar.d_; + } + } + } + + out.ok = std::isfinite(value_of_rec(out.log_Q)); + return out; +} + +/** + * Computes log q and d(log q) / d(alpha) using log1m. + */ +template +static inline Q_eval eval_q_log1m(const T& alpha, const T& beta_y) { + Q_eval out; + out.log_Q = log1m(gamma_p(alpha, beta_y)); + + if (!std::isfinite(value_of_rec(out.log_Q))) { + out.ok = false; + return out; + } + + if constexpr (is_autodiff_v) { + if constexpr (partials_fvar) { + T alpha_unit = alpha; + alpha_unit.d_ = 1; + T beta_unit = beta_y; + beta_unit.d_ = 0; + T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit)); + out.dlogQ_dalpha = log_Q_fvar.d_; + } else { + out.dlogQ_dalpha + = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q); + } + } + + out.ok = true; + return out; +} +} // namespace internal template inline return_type_t gamma_lccdf( const T_y& y, const T_shape& alpha, const T_inv_scale& beta) { - using T_partials_return = partials_return_t; using std::exp; using std::log; - using std::pow; + using T_partials_return = partials_return_t; using T_y_ref = ref_type_t; using T_alpha_ref = ref_type_t; using T_beta_ref = ref_type_t; @@ -53,14 +125,10 @@ inline return_type_t gamma_lccdf( scalar_seq_view beta_vec(beta_ref); size_t N = max_size(y, alpha, beta); - // Explicit return for extreme values - // The gradients are technically ill-defined, but treated as zero - for (size_t i = 0; i < stan::math::size(y); i++) { - if (y_vec.val(i) == 0) { - // LCCDF(0) = log(P(Y > 0)) = log(1) = 0 - return ops_partials.build(0.0); - } - } + constexpr bool any_fvar = is_fvar>::value + || is_fvar>::value + || is_fvar>::value; + constexpr bool partials_fvar = is_fvar::value; for (size_t n = 0; n < N; n++) { // Explicit results for extreme values @@ -79,15 +147,37 @@ inline return_type_t gamma_lccdf( const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl); const T_partials_return log_Qn = log(Qn); - P += log_Qn; + const bool use_continued_fraction = beta_y > alpha_dbl + 1.0; + internal::Q_eval result; + if (use_continued_fraction) { + result = internal::eval_q_cf(alpha_dbl, beta_y); + } else { + result + = internal::eval_q_log1m( + alpha_dbl, beta_y); + + if (!result.ok && beta_y > 0.0) { + // Fallback to continued fraction if log1m fails + result = internal::eval_q_cf(alpha_dbl, beta_y); + } + } + if (!result.ok) { + return ops_partials.build(negative_infinity()); + } + + P += result.log_Q; + + if constexpr (is_autodiff_v || is_autodiff_v) { + const T_partials_return log_y = log(y_dbl); + const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y); + + const T_partials_return log_pdf = alpha_dbl * log(beta_dbl) + - lgamma(alpha_dbl) + alpha_minus_one + - beta_y; - if constexpr (is_any_autodiff_v) { - const T_partials_return log_y_dbl = log(y_dbl); - const T_partials_return log_beta_dbl = log(beta_dbl); - const T_partials_return log_pdf - = alpha_dbl * log_beta_dbl - lgamma(alpha_dbl) - + (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl; - const T_partials_return common_term = exp(log_pdf - log_Qn); + const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q if constexpr (is_autodiff_v) { // d/dy log(1-F(y)) = -f(y)/(1-F(y)) @@ -100,12 +190,7 @@ inline return_type_t gamma_lccdf( } if constexpr (is_autodiff_v) { - const T_partials_return digamma_val = digamma(alpha_dbl); - const T_partials_return gamma_val = tgamma(alpha_dbl); - // d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y)) - partials<1>(ops_partials)[n] - += grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val) - / Qn; + partials<1>(ops_partials)[n] += result.dlogQ_dalpha; } } return ops_partials.build(P);