-
-
Notifications
You must be signed in to change notification settings - Fork 198
Use boost beta, simplify overloads #3212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
e5e2f13
544a5e8
2c81206
87f0309
f691622
0cbdf91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -42,32 +42,29 @@ namespace math { | |||||
| \end{cases} | ||||||
| \f] | ||||||
| * | ||||||
| * @tparam T inner type of the fvar | ||||||
| * @param x1 First value | ||||||
| * @param x2 Second value | ||||||
| * @tparam Ta Type of first scalar argument | ||||||
| * @tparam Tb Type of second scalar argument | ||||||
| * @param a First value | ||||||
| * @param b Second value | ||||||
| * @return Fvar with result beta function of arguments and gradients. | ||||||
| */ | ||||||
| template <typename T> | ||||||
| inline fvar<T> beta(const fvar<T>& x1, const fvar<T>& x2) { | ||||||
| const T beta_ab = beta(x1.val_, x2.val_); | ||||||
| return fvar<T>(beta_ab, | ||||||
| beta_ab | ||||||
| * (x1.d_ * digamma(x1.val_) + x2.d_ * digamma(x2.val_) | ||||||
| - (x1.d_ + x2.d_) * digamma(x1.val_ + x2.val_))); | ||||||
| } | ||||||
|
|
||||||
| template <typename T> | ||||||
| inline fvar<T> beta(double x1, const fvar<T>& x2) { | ||||||
| const T beta_ab = beta(x1, x2.val_); | ||||||
| return fvar<T>(beta_ab, | ||||||
| x2.d_ * (digamma(x2.val_) - digamma(x1 + x2.val_)) * beta_ab); | ||||||
| } | ||||||
|
|
||||||
| template <typename T> | ||||||
| inline fvar<T> beta(const fvar<T>& x1, double x2) { | ||||||
| const T beta_ab = beta(x1.val_, x2); | ||||||
| return fvar<T>(beta_ab, | ||||||
| x1.d_ * (digamma(x1.val_) - digamma(x1.val_ + x2)) * beta_ab); | ||||||
| template <typename Ta, typename Tb, | ||||||
| typename FvarInnerT = partials_return_t<Ta, Tb>, | ||||||
| require_return_type_t<is_fvar, Ta, Tb>* = nullptr, | ||||||
| require_all_stan_scalar_t<Ta, Tb>* = nullptr> | ||||||
| inline fvar<FvarInnerT> beta(const Ta& a, const Tb& b) { | ||||||
| const auto& a_val = value_of(a); | ||||||
| const auto& b_val = value_of(b); | ||||||
| const FvarInnerT beta_val = beta(a_val, b_val); | ||||||
| const FvarInnerT digamma_ab = digamma(a_val + b_val); | ||||||
| FvarInnerT beta_d(0); | ||||||
| if constexpr (!is_constant<Ta>::value) { | ||||||
| beta_d += (digamma(a_val) - digamma_ab) * beta_val * a.d_; | ||||||
| } | ||||||
| if constexpr (!is_constant<Tb>::value) { | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| beta_d += (digamma(b_val) - digamma_ab) * beta_val * b.d_; | ||||||
| } | ||||||
| return fvar<FvarInnerT>(beta_val, beta_d); | ||||||
| } | ||||||
|
|
||||||
| } // namespace math | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,10 +2,9 @@ | |
| #define STAN_MATH_PRIM_FUN_BETA_HPP | ||
|
|
||
| #include <stan/math/prim/meta.hpp> | ||
| #include <stan/math/prim/fun/exp.hpp> | ||
| #include <stan/math/prim/fun/lgamma.hpp> | ||
| #include <stan/math/prim/fun/boost_policy.hpp> | ||
| #include <stan/math/prim/functor/apply_scalar_binary.hpp> | ||
| #include <cmath> | ||
| #include <boost/math/special_functions/beta.hpp> | ||
|
|
||
| namespace stan { | ||
| namespace math { | ||
|
|
@@ -51,8 +50,7 @@ namespace math { | |
| */ | ||
| template <typename T1, typename T2, require_all_arithmetic_t<T1, T2>* = nullptr> | ||
| inline return_type_t<T1, T2> beta(const T1 a, const T2 b) { | ||
| using std::exp; | ||
| return exp(lgamma(a) + lgamma(b) - lgamma(a + b)); | ||
| return boost::math::beta(a, b, boost_policy_t<>()); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -65,8 +63,11 @@ inline return_type_t<T1, T2> beta(const T1 a, const T2 b) { | |
| * @param b Second input | ||
| * @return Beta function applied to the two inputs. | ||
| */ | ||
| template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr, | ||
| require_all_not_var_matrix_t<T1, T2>* = nullptr> | ||
| template < | ||
| typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr, | ||
| require_t<math::disjunction< | ||
| is_arithmetic<return_type_t<T1, T2>>, is_fvar<return_type_t<T1, T2>>, | ||
| is_std_vector<T1>, is_std_vector<T2>>>* = nullptr> | ||
|
Comment on lines
+66
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Little confused. Wouldn't |
||
| inline auto beta(T1&& a, T2&& b) { | ||
| return apply_scalar_binary( | ||
| [](auto&& c, auto&& d) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -34,197 +34,42 @@ namespace math { | |||||||||||||
| * @param b var Argument | ||||||||||||||
| * @return Result of beta function | ||||||||||||||
| */ | ||||||||||||||
| inline var beta(const var& a, const var& b) { | ||||||||||||||
| double digamma_ab = digamma(a.val() + b.val()); | ||||||||||||||
| double digamma_a = digamma(a.val()) - digamma_ab; | ||||||||||||||
| double digamma_b = digamma(b.val()) - digamma_ab; | ||||||||||||||
| return make_callback_var(beta(a.val(), b.val()), | ||||||||||||||
| [a, b, digamma_a, digamma_b](auto& vi) mutable { | ||||||||||||||
| const double adj_val = vi.adj() * vi.val(); | ||||||||||||||
| a.adj() += adj_val * digamma_a; | ||||||||||||||
| b.adj() += adj_val * digamma_b; | ||||||||||||||
| }); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * Returns the beta function and gradient for first var input. | ||||||||||||||
| * | ||||||||||||||
| \f[ | ||||||||||||||
| \mathrm{beta}(a,b) = \left(B\left(a,b\right)\right) | ||||||||||||||
| \f] | ||||||||||||||
| template <typename T1, typename T2, | ||||||||||||||
| require_all_not_std_vector_t<T1, T2>* = nullptr, | ||||||||||||||
| require_return_type_t<is_var, T1, T2>* = nullptr> | ||||||||||||||
| inline auto beta(const T1& a, const T2& b) { | ||||||||||||||
| arena_t<ref_type_t<T1>> arena_a = a; | ||||||||||||||
| arena_t<ref_type_t<T2>> arena_b = b; | ||||||||||||||
|
Comment on lines
+40
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| \f[ | ||||||||||||||
| \frac{\partial }{\partial a} = \left(\psi^{\left(0\right)}\left(a\right) | ||||||||||||||
| - \psi^{\left(0\right)} | ||||||||||||||
| \left(a + b\right)\right) | ||||||||||||||
| * \mathrm{beta}(a,b) | ||||||||||||||
| \f] | ||||||||||||||
| * | ||||||||||||||
| * @param a var Argument | ||||||||||||||
| * @param b double Argument | ||||||||||||||
| * @return Result of beta function | ||||||||||||||
| */ | ||||||||||||||
| inline var beta(const var& a, double b) { | ||||||||||||||
| auto digamma_ab = digamma(a.val()) - digamma(a.val() + b); | ||||||||||||||
| return make_callback_var(beta(a.val(), b), [a, digamma_ab](auto& vi) mutable { | ||||||||||||||
| a.adj() += vi.adj() * digamma_ab * vi.val(); | ||||||||||||||
| }); | ||||||||||||||
| } | ||||||||||||||
| const auto& beta_val = beta(value_of(arena_a), value_of(arena_b)); | ||||||||||||||
| using return_type_t = return_var_matrix_t<decltype(beta_val), T1, T2>; | ||||||||||||||
| arena_t<return_type_t> res(beta_val); | ||||||||||||||
|
Comment on lines
+44
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * Returns the beta function and gradient for second var input. | ||||||||||||||
| * | ||||||||||||||
| \f[ | ||||||||||||||
| \mathrm{beta}(a,b) = \left(B\left(a,b\right)\right) | ||||||||||||||
| \f] | ||||||||||||||
| reverse_pass_callback([arena_a, arena_b, res]() mutable { | ||||||||||||||
| auto&& a_array = as_array_or_scalar(arena_a); | ||||||||||||||
| auto&& b_array = as_array_or_scalar(arena_b); | ||||||||||||||
| const auto& res_array = as_array_or_scalar(res); | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use |
||||||||||||||
| const auto& digamma_ab = digamma(value_of(a_array) + value_of(b_array)); | ||||||||||||||
| const auto& adj_val = res_array.adj() * res_array.val(); | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This produces a new expression objet so we don't need a
Suggested change
Anywhere that we will produce a new Eigen expression it should just be an object. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same for several places here |
||||||||||||||
|
|
||||||||||||||
| \f[ | ||||||||||||||
| \frac{\partial }{\partial b} = \left(\psi^{\left(0\right)}\left(b\right) | ||||||||||||||
| - \psi^{\left(0\right)} | ||||||||||||||
| \left(a + b\right)\right) | ||||||||||||||
| * \mathrm{beta}(a,b) | ||||||||||||||
| \f] | ||||||||||||||
| * | ||||||||||||||
| * @param a double Argument | ||||||||||||||
| * @param b var Argument | ||||||||||||||
| * @return Result of beta function | ||||||||||||||
| */ | ||||||||||||||
| inline var beta(double a, const var& b) { | ||||||||||||||
| auto beta_val = beta(a, b.val()); | ||||||||||||||
| auto digamma_ab = (digamma(b.val()) - digamma(a + b.val())) * beta_val; | ||||||||||||||
| return make_callback_var(beta_val, [b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| b.adj() += vi.adj() * digamma_ab; | ||||||||||||||
| if constexpr (!is_constant<T1>::value) { | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to use the new |
||||||||||||||
| const auto& a_adj = adj_val * (digamma(a_array.val()) - digamma_ab); | ||||||||||||||
| if constexpr (is_stan_scalar<T1>::value) { | ||||||||||||||
| a_array.adj() += sum(a_adj); | ||||||||||||||
| } else { | ||||||||||||||
| a_array.adj() += a_adj; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| if constexpr (!is_constant<T2>::value) { | ||||||||||||||
| const auto& b_adj = adj_val * (digamma(b_array.val()) - digamma_ab); | ||||||||||||||
| if constexpr (is_stan_scalar<T2>::value) { | ||||||||||||||
| b_array.adj() += sum(b_adj); | ||||||||||||||
| } else { | ||||||||||||||
| b_array.adj() += b_adj; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| }); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| template <typename Mat1, typename Mat2, | ||||||||||||||
| require_any_var_matrix_t<Mat1, Mat2>* = nullptr, | ||||||||||||||
| require_all_matrix_t<Mat1, Mat2>* = nullptr> | ||||||||||||||
| inline auto beta(const Mat1& a, const Mat2& b) { | ||||||||||||||
| if constexpr (is_autodiff_v<Mat1> && is_autodiff_v<Mat2>) { | ||||||||||||||
| arena_t<promote_scalar_t<var, Mat1>> arena_a = a; | ||||||||||||||
| arena_t<promote_scalar_t<var, Mat2>> arena_b = b; | ||||||||||||||
| auto beta_val = beta(arena_a.val(), arena_b.val()); | ||||||||||||||
| auto digamma_ab | ||||||||||||||
| = to_arena(digamma(arena_a.val().array() + arena_b.val().array())); | ||||||||||||||
| return make_callback_var( | ||||||||||||||
| beta(arena_a.val(), arena_b.val()), | ||||||||||||||
| [arena_a, arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| const auto adj_val = (vi.adj().array() * vi.val().array()).eval(); | ||||||||||||||
| arena_a.adj().array() | ||||||||||||||
| += adj_val * (digamma(arena_a.val().array()) - digamma_ab); | ||||||||||||||
| arena_b.adj().array() | ||||||||||||||
| += adj_val * (digamma(arena_b.val().array()) - digamma_ab); | ||||||||||||||
| }); | ||||||||||||||
| } else if constexpr (is_autodiff_v<Mat1>) { | ||||||||||||||
| arena_t<promote_scalar_t<var, Mat1>> arena_a = a; | ||||||||||||||
| arena_t<promote_scalar_t<double, Mat2>> arena_b = value_of(b); | ||||||||||||||
| auto digamma_ab | ||||||||||||||
| = to_arena(digamma(arena_a.val()).array() | ||||||||||||||
| - digamma(arena_a.val().array() + arena_b.array())); | ||||||||||||||
| return make_callback_var(beta(arena_a.val(), arena_b), | ||||||||||||||
| [arena_a, arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| arena_a.adj().array() += vi.adj().array() | ||||||||||||||
| * digamma_ab | ||||||||||||||
| * vi.val().array(); | ||||||||||||||
| }); | ||||||||||||||
| } else if constexpr (is_autodiff_v<Mat2>) { | ||||||||||||||
| arena_t<promote_scalar_t<double, Mat1>> arena_a = value_of(a); | ||||||||||||||
| arena_t<promote_scalar_t<var, Mat2>> arena_b = b; | ||||||||||||||
| auto beta_val = beta(arena_a, arena_b.val()); | ||||||||||||||
| auto digamma_ab | ||||||||||||||
| = to_arena((digamma(arena_b.val()).array() | ||||||||||||||
| - digamma(arena_a.array() + arena_b.val().array())) | ||||||||||||||
| * beta_val.array()); | ||||||||||||||
| return make_callback_var( | ||||||||||||||
| beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| arena_b.adj().array() += vi.adj().array() * digamma_ab.array(); | ||||||||||||||
| }); | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| template <typename Scalar, typename VarMat, | ||||||||||||||
| require_var_matrix_t<VarMat>* = nullptr, | ||||||||||||||
| require_stan_scalar_t<Scalar>* = nullptr> | ||||||||||||||
| inline auto beta(const Scalar& a, const VarMat& b) { | ||||||||||||||
| if constexpr (is_autodiff_v<Scalar> && is_autodiff_v<VarMat>) { | ||||||||||||||
| var arena_a = a; | ||||||||||||||
| arena_t<promote_scalar_t<var, VarMat>> arena_b = b; | ||||||||||||||
| auto beta_val = beta(arena_a.val(), arena_b.val()); | ||||||||||||||
| auto digamma_ab = to_arena(digamma(arena_a.val() + arena_b.val().array())); | ||||||||||||||
| return make_callback_var( | ||||||||||||||
| beta(arena_a.val(), arena_b.val()), | ||||||||||||||
| [arena_a, arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| const auto adj_val = (vi.adj().array() * vi.val().array()).eval(); | ||||||||||||||
| arena_a.adj() | ||||||||||||||
| += (adj_val * (digamma(arena_a.val()) - digamma_ab)).sum(); | ||||||||||||||
| arena_b.adj().array() | ||||||||||||||
| += adj_val * (digamma(arena_b.val().array()) - digamma_ab); | ||||||||||||||
| }); | ||||||||||||||
| } else if constexpr (is_autodiff_v<Scalar>) { | ||||||||||||||
| var arena_a = a; | ||||||||||||||
| arena_t<promote_scalar_t<double, VarMat>> arena_b = value_of(b); | ||||||||||||||
| auto digamma_ab = to_arena(digamma(arena_a.val()) | ||||||||||||||
| - digamma(arena_a.val() + arena_b.array())); | ||||||||||||||
| return make_callback_var( | ||||||||||||||
| beta(arena_a.val(), arena_b), | ||||||||||||||
| [arena_a, arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| arena_a.adj() | ||||||||||||||
| += (vi.adj().array() * digamma_ab * vi.val().array()).sum(); | ||||||||||||||
| }); | ||||||||||||||
| } else if constexpr (is_autodiff_v<VarMat>) { | ||||||||||||||
| double arena_a = value_of(a); | ||||||||||||||
| arena_t<promote_scalar_t<var, VarMat>> arena_b = b; | ||||||||||||||
| auto beta_val = beta(arena_a, arena_b.val()); | ||||||||||||||
| auto digamma_ab = to_arena((digamma(arena_b.val()).array() | ||||||||||||||
| - digamma(arena_a + arena_b.val().array())) | ||||||||||||||
| * beta_val.array()); | ||||||||||||||
| return make_callback_var(beta_val, [arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| arena_b.adj().array() += vi.adj().array() * digamma_ab.array(); | ||||||||||||||
| }); | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| template <typename VarMat, typename Scalar, | ||||||||||||||
| require_var_matrix_t<VarMat>* = nullptr, | ||||||||||||||
| require_stan_scalar_t<Scalar>* = nullptr> | ||||||||||||||
| inline auto beta(const VarMat& a, const Scalar& b) { | ||||||||||||||
| if constexpr (is_autodiff_v<VarMat> && is_autodiff_v<Scalar>) { | ||||||||||||||
| arena_t<promote_scalar_t<var, VarMat>> arena_a = a; | ||||||||||||||
| var arena_b = b; | ||||||||||||||
| auto beta_val = beta(arena_a.val(), arena_b.val()); | ||||||||||||||
| auto digamma_ab = to_arena(digamma(arena_a.val().array() + arena_b.val())); | ||||||||||||||
| return make_callback_var( | ||||||||||||||
| beta(arena_a.val(), arena_b.val()), | ||||||||||||||
| [arena_a, arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| const auto adj_val = (vi.adj().array() * vi.val().array()).eval(); | ||||||||||||||
| arena_a.adj().array() | ||||||||||||||
| += adj_val * (digamma(arena_a.val().array()) - digamma_ab); | ||||||||||||||
| arena_b.adj() | ||||||||||||||
| += (adj_val * (digamma(arena_b.val()) - digamma_ab)).sum(); | ||||||||||||||
| }); | ||||||||||||||
| } else if constexpr (is_autodiff_v<VarMat>) { | ||||||||||||||
| arena_t<promote_scalar_t<var, VarMat>> arena_a = a; | ||||||||||||||
| double arena_b = value_of(b); | ||||||||||||||
| auto digamma_ab = to_arena(digamma(arena_a.val()).array() | ||||||||||||||
| - digamma(arena_a.val().array() + arena_b)); | ||||||||||||||
| return make_callback_var( | ||||||||||||||
| beta(arena_a.val(), arena_b), [arena_a, digamma_ab](auto& vi) mutable { | ||||||||||||||
| arena_a.adj().array() | ||||||||||||||
| += vi.adj().array() * digamma_ab * vi.val().array(); | ||||||||||||||
| }); | ||||||||||||||
| } else if constexpr (is_autodiff_v<Scalar>) { | ||||||||||||||
| arena_t<promote_scalar_t<double, VarMat>> arena_a = value_of(a); | ||||||||||||||
| var arena_b = b; | ||||||||||||||
| auto beta_val = beta(arena_a, arena_b.val()); | ||||||||||||||
| auto digamma_ab = to_arena( | ||||||||||||||
| (digamma(arena_b.val()) - digamma(arena_a.array() + arena_b.val())) | ||||||||||||||
| * beta_val.array()); | ||||||||||||||
| return make_callback_var( | ||||||||||||||
| beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable { | ||||||||||||||
| arena_b.adj() += (vi.adj().array() * digamma_ab.array()).sum(); | ||||||||||||||
| }); | ||||||||||||||
| } | ||||||||||||||
| return return_type_t(res); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| } // namespace math | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,8 +26,8 @@ TEST(MathFunctions, beta_vec) { | |
| auto f | ||
| = [](const auto& x1, const auto& x2) { return stan::math::beta(x1, x2); }; | ||
|
|
||
| Eigen::VectorXd in1 = Eigen::VectorXd::Random(6); | ||
| Eigen::VectorXd in2 = Eigen::VectorXd::Random(6); | ||
| Eigen::VectorXd in1 = Eigen::VectorXd::Random(6).cwiseAbs(); | ||
| Eigen::VectorXd in2 = Eigen::VectorXd::Random(6).cwiseAbs(); | ||
|
Comment on lines
+29
to
+30
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did this have to change? |
||
|
|
||
| stan::test::binary_scalar_tester(f, in1, in2); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.