Use boost beta, simplify overloads #3212
base: develop
@SteveBronder what do you think of this structure for combining overloads with |
| const FvarInnerT beta_val = beta(a_val, b_val); | ||
| const FvarInnerT digamma_ab = digamma(a_val + b_val); | ||
| FvarInnerT beta_d(0); | ||
| if constexpr (!is_constant<Ta>::value) { |
| if constexpr (!is_constant<Ta>::value) { | |
| if constexpr (is_autodiff_v<Ta>) { |
| if constexpr (!is_constant<Ta>::value) { | ||
| beta_d += (digamma(a_val) - digamma_ab) * beta_val * a.d_; | ||
| } | ||
| if constexpr (!is_constant<Tb>::value) { |
| if constexpr (!is_constant<Tb>::value) { | |
| if constexpr (is_autodiff_v<Tb>) { |
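For reference, here is a minimal standalone sketch of the single-overload pattern these two suggestions touch. It does not use Stan Math: `ToyFvar`, `toy_is_autodiff_v`, `toy_value_of`, `toy_digamma` and `toy_beta` are hypothetical stand-ins, and the beta value comes from `std::lgamma` rather than Boost. The point is only that each `if constexpr` branch adds a tangent contribution when that argument is an autodiff type, so one template covers the (fvar, fvar), (fvar, double) and (double, fvar) cases.

```cpp
#include <cassert>
#include <cmath>
#include <type_traits>

// Toy forward-mode type (assumption): stands in for an fvar with value and tangent.
struct ToyFvar {
  double val_;
  double d_;
};

// Hypothetical trait playing the role of is_autodiff_v in the suggestion.
template <typename T>
inline constexpr bool toy_is_autodiff_v
    = std::is_same_v<std::decay_t<T>, ToyFvar>;

inline double toy_value_of(double x) { return x; }
inline double toy_value_of(const ToyFvar& x) { return x.val_; }

// Digamma for x > 0: upward recurrence to x >= 6, then an asymptotic series.
// Illustration only, not a production implementation.
inline double toy_digamma(double x) {
  double result = 0.0;
  while (x < 6.0) {
    result -= 1.0 / x;
    x += 1.0;
  }
  const double inv2 = 1.0 / (x * x);
  return result + std::log(x) - 0.5 / x
         - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0));
}

// One overload for every mix of ToyFvar and double arguments.
template <typename Ta, typename Tb>
inline ToyFvar toy_beta(const Ta& a, const Tb& b) {
  static_assert(toy_is_autodiff_v<Ta> || toy_is_autodiff_v<Tb>,
                "at least one argument must be a ToyFvar");
  const double a_val = toy_value_of(a);
  const double b_val = toy_value_of(b);
  const double beta_val = std::exp(std::lgamma(a_val) + std::lgamma(b_val)
                                   - std::lgamma(a_val + b_val));
  const double digamma_ab = toy_digamma(a_val + b_val);
  double beta_d = 0.0;
  // The discarded branch is never instantiated, so a.d_ / b.d_ is only
  // touched when that argument really is a ToyFvar.
  if constexpr (toy_is_autodiff_v<Ta>) {
    beta_d += (toy_digamma(a_val) - digamma_ab) * beta_val * a.d_;
  }
  if constexpr (toy_is_autodiff_v<Tb>) {
    beta_d += (toy_digamma(b_val) - digamma_ab) * beta_val * b.d_;
  }
  return {beta_val, beta_d};
}
```

Calling `toy_beta(ToyFvar{2.0, 1.0}, 3.0)` only instantiates the first branch, while passing two `ToyFvar` arguments accumulates both tangent terms.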
| template < | ||
| typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr, | ||
| require_t<math::disjunction< | ||
| is_arithmetic<return_type_t<T1, T2>>, is_fvar<return_type_t<T1, T2>>, | ||
| is_std_vector<T1>, is_std_vector<T2>>>* = nullptr> |
Little confused. Wouldn't is_fvar<return_type> conflict with the fvar version of beta?
| inline auto beta(const T1& a, const T2& b) { | ||
| arena_t<ref_type_t<T1>> arena_a = a; | ||
| arena_t<ref_type_t<T2>> arena_b = b; |
| inline auto beta(const T1& a, const T2& b) { | |
| arena_t<ref_type_t<T1>> arena_a = a; | |
| arena_t<ref_type_t<T2>> arena_b = b; | |
| inline auto beta(T1&& a, T2&& b) { | |
| arena_t<ref_type_t<T1>> arena_a = std::forward<T1>(a); | |
| arena_t<ref_type_t<T2>> arena_b = std::forward<T2>(b); |
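To illustrate what the forwarding-reference signature buys, here is a small sketch that is independent of Stan Math. `Tracker`, `keep_by_const_ref` and `keep_by_forward` are hypothetical names made up for this example; whether `arena_t` actually benefits from receiving an rvalue is not something this sketch shows. It only demonstrates the general rule that a `const T&` parameter always copies into the stored object, while `T&&` plus `std::forward` moves from rvalues and still copies from lvalues.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Counts how many copies and moves produced this object (illustration only).
struct Tracker {
  int copies = 0;
  int moves = 0;
  Tracker() = default;
  Tracker(const Tracker& o) : copies(o.copies + 1), moves(o.moves) {}
  Tracker(Tracker&& o) noexcept : copies(o.copies), moves(o.moves + 1) {}
};

// const-reference parameter: the argument is copied even if it was a temporary.
template <typename T>
inline T keep_by_const_ref(const T& x) {
  return T(x);
}

// Forwarding reference: rvalues are moved, lvalues are copied.
template <typename T>
inline std::decay_t<T> keep_by_forward(T&& x) {
  return std::decay_t<T>(std::forward<T>(x));
}
```

With C++17 guaranteed copy elision on the return, `keep_by_forward(Tracker{})` records one move and no copies, whereas `keep_by_const_ref(Tracker{})` records one copy.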
| const auto& beta_val = beta(value_of(arena_a), value_of(arena_b)); | ||
| using return_type_t = return_var_matrix_t<decltype(beta_val), T1, T2>; | ||
| arena_t<return_type_t> res(beta_val); |
| const auto& beta_val = beta(value_of(arena_a), value_of(arena_b)); | |
| using return_type_t = return_var_matrix_t<decltype(beta_val), T1, T2>; | |
| arena_t<return_type_t> res(beta_val); | |
| using inner_ret_t = decltype(beta(value_of(arena_a), value_of(arena_b))); | |
| using return_type_t = return_var_matrix_t<inner_ret_t, T1, T2>; | |
| arena_t<return_type_t> res(beta(value_of(arena_a), value_of(arena_b))); |
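One plausible reason for taking `decltype` of the call expression instead of the named variable: `decltype` applied to a name declared as `const auto&` yields a const reference type, while `decltype` applied to a call that returns by value yields the plain value type. The sketch below checks this with standard type traits; `make_values` is a hypothetical stand-in for the inner `beta` call, and how `return_var_matrix_t` treats either type is not covered here.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for a function that returns a new object by value.
inline std::vector<double> make_values() { return {1.0, 2.0}; }

// decltype of a name declared `const auto&` is a const lvalue reference.
inline bool bound_name_is_const_ref() {
  const auto& val = make_values();
  return std::is_same_v<decltype(val), const std::vector<double>&>;
}

// decltype of the call expression itself is the value type.
inline constexpr bool call_expr_is_value
    = std::is_same_v<decltype(make_values()), std::vector<double>>;
```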
| auto&& b_array = as_array_or_scalar(arena_b); | ||
| const auto& res_array = as_array_or_scalar(res); | ||
| const auto& digamma_ab = digamma(value_of(a_array) + value_of(b_array)); | ||
| const auto& adj_val = res_array.adj() * res_array.val(); |
This produces a new expression object, so we don't need a const& to it
| const auto& adj_val = res_array.adj() * res_array.val(); | |
| auto adj_val = res_array.adj() * res_array.val(); |
Anywhere that we will produce a new Eigen expression it should just be an object.
Same for several places here
| reverse_pass_callback([arena_a, arena_b, res]() mutable { | ||
| auto&& a_array = as_array_or_scalar(arena_a); | ||
| auto&& b_array = as_array_or_scalar(arena_b); | ||
| const auto& res_array = as_array_or_scalar(res); |
You can use decltype(auto) here, since as_array_or_scalar can return either a reference or a new object and decltype(auto) preserves whichever it gets. The function returns a reference for a scalar, or when an array is passed in as input and handed back as output. It returns a new object when, for example, we are returning an array view of a matrix. That array view is an actual object.
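A small standalone sketch of that behaviour, without Eigen: `toy_as_array_or_scalar` and `ToyView` are hypothetical stand-ins for the real function and for an array view. The scalar overload returns a reference and the container overload returns a new lightweight view object, and `decltype(auto)` deduces `double&` in the first case and `ToyView` in the second.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

// Hypothetical view type: a new object that refers to existing storage.
struct ToyView {
  std::vector<double>* data;
};

// Scalar case: hand back a reference to the input.
inline double& toy_as_array_or_scalar(double& x) { return x; }

// Container case: build and return a new view object by value.
inline ToyView toy_as_array_or_scalar(std::vector<double>& v) {
  return ToyView{&v};
}

inline bool scalar_case_is_reference() {
  double x = 1.0;
  decltype(auto) r = toy_as_array_or_scalar(x);
  static_assert(std::is_same_v<decltype(r), double&>);
  r = 2.0;          // writes through to x
  return x == 2.0;
}

inline bool view_case_is_object() {
  std::vector<double> v{1.0, 2.0};
  decltype(auto) view = toy_as_array_or_scalar(v);
  static_assert(std::is_same_v<decltype(view), ToyView>);
  (*view.data)[0] = 5.0;  // the view object still refers to v
  return v[0] == 5.0;
}
```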
| auto digamma_ab = (digamma(b.val()) - digamma(a + b.val())) * beta_val; | ||
| return make_callback_var(beta_val, [b, digamma_ab](auto& vi) mutable { | ||
| b.adj() += vi.adj() * digamma_ab; | ||
| if constexpr (!is_constant<T1>::value) { |
It would be nice to use the new is_autodiff_v everywhere you use !is_constant
| Eigen::VectorXd in1 = Eigen::VectorXd::Random(6).cwiseAbs(); | ||
| Eigen::VectorXd in2 = Eigen::VectorXd::Random(6).cwiseAbs(); |
Why did this have to change?
Summary
This PR replaces our manual implementation of the beta function with the one from Boost, which is more accurate over a broader domain.
I've also combined the multiple overloads into a single overload per prim, rev, fwd now that we have if constexpr.
Tests
N/A
Side Effects
N/A
Release notes
Use Boost's implementation of beta for improved stability
Checklist
Copyright holder: Andrew Johnson
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
- the basic tests are passing
  - ./runTests.py test/unit
  - make test-headers
  - make test-math-dependencies
  - make doxygen
  - make cpplint
- the code is written in idiomatic C++ and changes are documented in the doxygen
- the new changes are tested