diff --git a/CHANGELOG.md b/CHANGELOG.md index 29069029c29..5cd748f6984 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,12 @@ ### Added (new features/APIs/variables/...) - [[PR556]](https://github.com/lanl/singularity-eos/pull/556) Add introspection into types available in the variant +- [[PR564]](https://github.com/lanl/singularity-eos/pull/564) Removed Get() function from IndexableTypes since it could have unexpected consequences when a type wasn't present ### Fixed (Repair bugs, etc) - [[PR561]](https://github.com/lanl/singularity-eos/pull/561) Fix logic for kokkos-kernels in spackage so that it is only required for closure models on GPU - [[PR563]](https://github.com/lanl/singularity-eos/pull/563) Fixed DensityFromPressureTemperature for the Carnahan-Starling EOS. +- [[PR564]](https://github.com/lanl/singularity-eos/pull/564) Fix logic for numerical vs type indices by adding safeGet(), safeSet(), safeMustGet(), and safeMustSet() helpers ### Changed (changing behavior/API/variables/...) diff --git a/doc/sphinx/src/using-eos.rst b/doc/sphinx/src/using-eos.rst index f5e26b924d2..326c4173367 100644 --- a/doc/sphinx/src/using-eos.rst +++ b/doc/sphinx/src/using-eos.rst @@ -699,6 +699,8 @@ of the ``[]`` operator that takes your type. For example: class MyLambda_t { public: + // Enable recognition that this is type-indexable + constexpr static bool is_type_indexable = true; MyLambda_t() = default; PORTABLE_FORCEINLINE_FUNCTION Real &operator[](const std::size_t idx) const { @@ -724,17 +726,92 @@ which might be used as where ``MeanIonizationState`` is shorthand for index 2, since you defined that overload. Note that the ``operator[]`` must be marked ``const``. To more easily enable mixing and matching integer-based -indexing with type-based indexing, the function +indexing with type-based indexing, the functions .. code-block:: cpp - template + template PORTABLE_FORCEINLINE_FUNCTION - Real &Get(Indexer_t &&lambda, std::size_t idx = 0); + bool SafeGet(Indexer_t const &lambda, std::size_t const idx, Real &out); -will return a reference to the value at named index ``Name_t()`` if -that overload is defined in ``Indexer_t`` and otherwise return a -reference at index ``idx``. +.. code-block:: cpp + + template + PORTABLE_FORCEINLINE_FUNCTION + bool SafeGet(Indexer_t const &lambda, Real &out); + +.. code-block:: cpp + + template + PORTABLE_FORCEINLINE_FUNCTION + Real SafeMustGet(Indexer_t const &lambda, std::size_t const idx) + +.. code-block:: cpp + + template + PORTABLE_FORCEINLINE_FUNCTION + Real SafeMustGet(Indexer_t const &lambda) + +will update the value of ``out`` with the value at either the appropriate +type-based index, ``T``, or the numerical index, ``idx``, if the ``Indexer_t`` +doesn't accept type-based indexing. If the ``Indexer_t`` **does** accept +type-based indexing but **doesn't** have the requested index, then the +``out`` value is not updated. The same is true for when ``Indexer_t`` is the +``nullptr``. The overload that doesn't take a numerical index will *only* +return the value at a type-based index. + +The ``SafeMustGet()`` version is intended to generate errors if a value can't be +retrieved with four types of errors that can occur: + +1. If a null pointer is passed as the indexer, a runtime abort or exception will + occur +2. If a type-based indexer is passed (i.e. one with the ``constexpr static bool`` + member ``is_type_indexable = true``), but the type doesn't exist then a **static** + assertion will fail +3. If one of the overloads is used where a ``std::size_t`` index is provided, but + the indexer can't accept integer indexing, then a **static** assertion will fail +4. If the indexer can't use type-based indexing but an ``std::size_t`` index + wasn't provided, then a **static** assertion will fail + +Similarly, the functions + +.. code-block:: cpp + + template + PORTABLE_FORCEINLINE_FUNCTION + inline bool SafeSet(Indexer_t &lambda, std::size_t const idx, Real const in); + +.. code-block:: cpp + + template + PORTABLE_FORCEINLINE_FUNCTION + inline bool SafeSet(Indexer_t &lambda, Real const in) + +.. code-block:: cpp + + template + PORTABLE_FORCEINLINE_FUNCTION + inline bool SafeMustSet(Indexer_t &lambda, std::size_t const idx, Real const in); + +.. code-block:: cpp + + template + PORTABLE_FORCEINLINE_FUNCTION + inline bool SafeMustSet(Indexer_t &lambda, Real const in) + +can modify the values in the ``Indexer_t`` and behave the same way. In this way, +if a type-based index isn't present in the container, then a different index won't +be overwritten. Again, the ``SafeMustSet()`` version will compile-time fail +or runtime abort if the lambda value can't be modified for the same reasons as +``SafeMustGet()``. + +.. note:: + + Previous versions defined a ``Get()`` function that was "unsafe" in the + sense that it would fall back on the numerical index even if a type-based + indexer was used. This could result in retrieving and overwriting incorrect + values in the indexer. We recommend not using this function and instead using + the "safe" versions. As a convenience tool, the struct diff --git a/singularity-eos/base/indexable_types.hpp b/singularity-eos/base/indexable_types.hpp index 25cc4928d95..9f6d4737f0e 100644 --- a/singularity-eos/base/indexable_types.hpp +++ b/singularity-eos/base/indexable_types.hpp @@ -20,45 +20,231 @@ #include #include +#include #include namespace singularity { namespace IndexerUtils { -// Convenience function for accessing an indexer by either type or -// natural number index depending on what is available -template -PORTABLE_FORCEINLINE_FUNCTION auto &Get(Indexer_t &&lambda, std::size_t idx = 0) { + +// Identifies an indexer as a type-based indexer +template +struct is_type_indexer : std::false_type {}; +template +struct is_type_indexer::is_type_indexable)>> + : std::bool_constant::is_type_indexable> {}; +template +constexpr bool is_type_indexer_v = is_type_indexer::value; + +namespace impl { + +// Simple way to switch between pure type indexing or also allowing intergers +enum class AllowedIndexing { Numeric, TypeOnly }; + +// The "safe" version of Get(). This function will ONLY return a value IF that +// type-based index is present in the Indexer OR if the Indexer doesn't support +// type-based indexing. +template +PORTABLE_FORCEINLINE_FUNCTION bool SafeGet(Indexer_t const &lambda, std::size_t const idx, + Real &out) { + // If null then nothing happens + if (variadic_utils::is_nullptr(lambda)) { + return false; + } + + // Return value if type index is available if constexpr (variadic_utils::is_indexable_v) { - return lambda[T()]; + out = lambda[T{}]; + return true; + } + + // Do nothing if lambda has type indexing BUT doesn't have this type index + if constexpr (is_type_indexer_v) { + return false; + } + + // Fall back to numeric indexing if allowed + if constexpr (AI == AllowedIndexing::Numeric) { + if constexpr (variadic_utils::has_int_index_v) { + out = lambda[idx]; + return true; + } + } + + // Something else... + return false; +} + +// Break out "Set" functionality from "Get". The original "Get()" did both, but +// the "safe" version needs to separate that functionality for setting the +// values in a lambda +template +PORTABLE_FORCEINLINE_FUNCTION bool SafeSet(Indexer_t &lambda, std::size_t const idx, + Real const in) { + // If null then nothing happens + if (variadic_utils::is_nullptr(lambda)) { + return false; + } + + // Return value if type index is available + if constexpr (variadic_utils::is_indexable_v) { + lambda[T{}] = in; + return true; + } + + // Do nothing if lambda has type indexing BUT doesn't have this type index + if constexpr (is_type_indexer_v) { + return false; + } + + // Fall back to numeric indexing if allowed + if constexpr (AI == AllowedIndexing::Numeric) { + if constexpr (variadic_utils::has_int_index_v) { + lambda[idx] = in; + return true; + } + } + + // Something else... + return false; +} + +// Same as above but causes an error condition (static or dynamic) if the value +// can't be obtained. Note that the `decltype(auto)` is intended to preserve the +// value category of the square bracket operator of the `Indexer_t` type. This +// allows references to be returned since there is also no possibility of the +// call doing nothing (i.e. like SafeGet and SafeSet), and thus it can be used +// for either setting or getting values. This should also allow for `const` +// correctness downstream in the wrappers where `lambda` is `const` +template +PORTABLE_FORCEINLINE_FUNCTION decltype(auto) SafeMustGetSet(Indexer_t &&lambda, + std::size_t const idx) { + // Error on null pointer + PORTABLE_ALWAYS_REQUIRE(!variadic_utils::is_nullptr(lambda), + "Indexer can't be nullptr"); + + if constexpr (is_type_indexer_v) { + // Return type-based index. Static assert that type MUST exist in indexer + static_assert(variadic_utils::is_indexable_v); + // Use std::forward to maintain value category for lambda, and use + // parentheses to do the same for the output of the lambda[] operation + return (std::forward(lambda)[T{}]); + } else if constexpr (AI == AllowedIndexing::Numeric) { + // Fall back to numerical indexing if allowed + static_assert(variadic_utils::has_int_index_v); + // Use std::forward to maintain value category for lambda, and use + // parentheses to do the same for the output of the lambda[] operation + return (std::forward(lambda)[idx]); } else { - return lambda[idx]; + // Something else that can't be compiled... + static_assert(variadic_utils::dependent_false_v, + "Indexer must either be designated as type-based through a " + "`is_type_indexable` boolean data member or SafeGet/SafeSet function " + "must be called with a numerical index"); } } +} // namespace impl + +// Overload when numerical index is provided +template +PORTABLE_FORCEINLINE_FUNCTION bool SafeGet(Indexer_t const &lambda, std::size_t const idx, + Real &out) { + return impl::SafeGet(lambda, idx, out); +} + +// Overload when numerical index isn't provided +template +PORTABLE_FORCEINLINE_FUNCTION bool SafeGet(Indexer_t const &lambda, Real &out) { + std::size_t idx = 0; + return impl::SafeGet(lambda, idx, out); +} + +// Overload when numerical index is provided +template +PORTABLE_FORCEINLINE_FUNCTION bool SafeSet(Indexer_t &lambda, std::size_t const idx, + Real const in) { + return impl::SafeSet(lambda, idx, in); +} + +// Overload when numerical index isn't provided +template +PORTABLE_FORCEINLINE_FUNCTION bool SafeSet(Indexer_t &lambda, Real const in) { + std::size_t idx = 0; + return impl::SafeSet(lambda, idx, in); +} + +// Overload when numerical index is provided +template +PORTABLE_FORCEINLINE_FUNCTION Real SafeMustGet(Indexer_t const &lambda, + std::size_t const idx) { + return impl::SafeMustGetSet(lambda, idx); +} + +// Overload when numerical index isn't provided +template +PORTABLE_FORCEINLINE_FUNCTION Real SafeMustGet(Indexer_t const &lambda) { + std::size_t idx = 0; + return impl::SafeMustGetSet(lambda, idx); +} + +// Overload when numerical index is provided +template +PORTABLE_FORCEINLINE_FUNCTION void SafeMustSet(Indexer_t &lambda, std::size_t const idx, + Real const in) { + impl::SafeMustGetSet(lambda, idx) = in; +} + +// Overload when numerical index isn't provided +template +PORTABLE_FORCEINLINE_FUNCTION void SafeMustSet(Indexer_t &lambda, Real const in) { + std::size_t idx = 0; + impl::SafeMustGetSet(lambda, idx) = in; +} + // This is a convenience struct to easily build a small indexer with // a set of indexable types. template class VariadicIndexerBase { public: + // Any class that wants to be recognized as indexable (so that we don't + // accidentally fall back to integer indexing when we don't want to) needs to + // include this. + constexpr static bool is_type_indexable = true; + + // JHP: another option for the `is_type_indexable` flag is to take the ADL + // route. Essentially this would involve defining a friend function that + // could be defined in an appropriate namesapce so that theoretically a host + // code could use a TPL container with type-based indexing and allow that + // container to be flagged in our code as acceptable. This seems like a bit + // of a heavy hammer for what we need here though. We can easily change this + // if a TPL provides a type that is being used for this purpose. + VariadicIndexerBase() = default; + PORTABLE_FORCEINLINE_FUNCTION VariadicIndexerBase(const Data_t &data) : data_(data) {} + template ::value>> PORTABLE_FORCEINLINE_FUNCTION Real &operator[](const T &t) { constexpr std::size_t idx = variadic_utils::GetIndexInTL(); return data_[idx]; } + PORTABLE_FORCEINLINE_FUNCTION Real &operator[](const std::size_t idx) { return data_[idx]; } + template ::value>> PORTABLE_FORCEINLINE_FUNCTION const Real &operator[](const T &t) const { constexpr std::size_t idx = variadic_utils::GetIndexInTL(); return data_[idx]; } + PORTABLE_FORCEINLINE_FUNCTION const Real &operator[](const std::size_t idx) const { return data_[idx]; } + static inline constexpr std::size_t size() { return sizeof...(Ts); } private: @@ -70,6 +256,7 @@ using VariadicIndexer = VariadicIndexerBase, Ts. // uses a Real* template using VariadicPointerIndexer = VariadicIndexerBase; + } // namespace IndexerUtils namespace IndexableTypes { diff --git a/singularity-eos/base/variadic_utils.hpp b/singularity-eos/base/variadic_utils.hpp index fc282e4c6c8..8b30c0cb5b2 100644 --- a/singularity-eos/base/variadic_utils.hpp +++ b/singularity-eos/base/variadic_utils.hpp @@ -24,6 +24,16 @@ namespace variadic_utils { // Some generic variatic utilities // ====================================================================== +// Template parameter dependent boolean suitable for causing static_assert +// errors within `if constexpr` branches. Essentially the issue is that if the +// static_assert _always_ evaluates to false, then it will _always_ cause a +// compile time error even if that branch of the code will never be reached. +// Making the evaluation (superficially) dependent on the template deduction +// causes it to be evaluated after the `if constexpr` branching has already been +// determined. See https://en.cppreference.com/w/cpp/language/if.html#Constexpr_if +template +inline constexpr bool dependent_false_v = false; + // Useful for generating nullptr of a specific pointer type template inline constexpr T *np() { @@ -111,6 +121,15 @@ struct is_indexable constexpr bool is_indexable_v = is_indexable::value; +// Check if a type can accept an int index +template +struct has_int_index : std::false_type {}; +template +struct has_int_index()[std::declval()])>> + : std::true_type {}; +template +constexpr bool has_int_index_v = has_int_index::value; + // this flattens a typelist of typelists to a single typelist // first parameter - accumulator diff --git a/singularity-eos/eos/eos_electrons.hpp b/singularity-eos/eos/eos_electrons.hpp index 0f5c262c6d0..7ac3f61842d 100644 --- a/singularity-eos/eos/eos_electrons.hpp +++ b/singularity-eos/eos/eos_electrons.hpp @@ -209,8 +209,8 @@ class IdealElectrons : public EosBase { private: template PORTABLE_INLINE_FUNCTION Real _Cv(Indexer_t &&lambda) const { - const Real Z = - IndexerUtils::Get(lambda, Lambda::Zi); + Real Z = IndexerUtils::SafeMustGet(lambda, + Lambda::Zi); return _Cvbase * std::max(Z, static_cast(0.0)); } diff --git a/singularity-eos/eos/eos_helmholtz.hpp b/singularity-eos/eos/eos_helmholtz.hpp index 08157606b45..665bf5925ed 100644 --- a/singularity-eos/eos/eos_helmholtz.hpp +++ b/singularity-eos/eos/eos_helmholtz.hpp @@ -630,8 +630,10 @@ class Helmholtz : public EosBase { Indexer_t &&lambda = static_cast(nullptr)) const { using namespace HelmUtils; Real p[NDERIV], e[NDERIV], s[NDERIV], etaele[NDERIV], nep[NDERIV]; - Real abar = IndexerUtils::Get(lambda, Lambda::Abar); - Real zbar = IndexerUtils::Get(lambda, Lambda::Zbar); + Real abar = + IndexerUtils::SafeMustGet(lambda, Lambda::Abar); + Real zbar = + IndexerUtils::SafeMustGet(lambda, Lambda::Zbar); Real ytot, ye, ywot, De, lDe; GetElectronDensities_(rho, abar, zbar, ytot, ye, ywot, De, lDe); Real lT = lTFromRhoSie_(rho, sie, abar, zbar, ye, ytot, ywot, De, lDe, lambda); @@ -657,14 +659,16 @@ class Helmholtz : public EosBase { const Real rho, const Real T, Indexer_t &&lambda = static_cast(nullptr)) const { using namespace HelmUtils; - return IndexerUtils::Get(lambda, Lambda::Abar); + return IndexerUtils::SafeMustGet(lambda, + Lambda::Abar); } template PORTABLE_INLINE_FUNCTION Real MeanAtomicNumberFromDensityTemperature( const Real rho, const Real T, Indexer_t &&lambda = static_cast(nullptr)) const { using namespace HelmUtils; - return IndexerUtils::Get(lambda, Lambda::Zbar); + return IndexerUtils::SafeMustGet(lambda, + Lambda::Zbar); } template @@ -758,10 +762,12 @@ class Helmholtz : public EosBase { GetFromDensityTemperature_(const Real rho, const Real temperature, Indexer_t &&lambda, Real p[NDERIV], Real e[NDERIV], Real s[NDERIV], Real etaele[NDERIV], Real nep[NDERIV]) const { - Real abar = IndexerUtils::Get(lambda, Lambda::Abar); - Real zbar = IndexerUtils::Get(lambda, Lambda::Zbar); + Real abar = + IndexerUtils::SafeMustGet(lambda, Lambda::Abar); + Real zbar = + IndexerUtils::SafeMustGet(lambda, Lambda::Zbar); Real lT = std::log10(temperature); - IndexerUtils::Get(lambda, Lambda::lT) = lT; + IndexerUtils::SafeGet(lambda, Lambda::lT, lT); Real ytot, ye, ywot, De, lDe; GetElectronDensities_(rho, abar, zbar, ytot, ye, ywot, De, lDe); GetFromDensityLogTemperature_(rho, temperature, abar, zbar, ye, ytot, ywot, De, lDe, @@ -773,8 +779,10 @@ class Helmholtz : public EosBase { GetFromDensityInternalEnergy_(const Real rho, const Real sie, Indexer_t &&lambda, Real p[NDERIV], Real e[NDERIV], Real s[NDERIV], Real etaele[NDERIV], Real nep[NDERIV]) const { - Real abar = IndexerUtils::Get(lambda, Lambda::Abar); - Real zbar = IndexerUtils::Get(lambda, Lambda::Zbar); + Real abar = + IndexerUtils::SafeMustGet(lambda, Lambda::Abar); + Real zbar = + IndexerUtils::SafeMustGet(lambda, Lambda::Zbar); Real ytot, ye, ywot, De, lDe; GetElectronDensities_(rho, abar, zbar, ytot, ye, ywot, De, lDe); Real lT = lTFromRhoSie_(rho, sie, abar, zbar, ye, ytot, ywot, De, lDe, lambda); @@ -819,8 +827,10 @@ Helmholtz::FillEos(Real &rho, Real &temp, Real &energy, Real &press, Real &cv, R PORTABLE_ALWAYS_REQUIRE( !(need_temp && need_sie), "Either specific internal energy or temperature must be provided."); - Real abar = IndexerUtils::Get(lambda, Lambda::Abar); - Real zbar = IndexerUtils::Get(lambda, Lambda::Zbar); + Real abar = + IndexerUtils::SafeMustGet(lambda, Lambda::Abar); + Real zbar = + IndexerUtils::SafeMustGet(lambda, Lambda::Zbar); Real ytot, ye, ywot, De, lDe, lT; GetElectronDensities_(rho, abar, zbar, ytot, ye, ywot, De, lDe); if (need_temp) { @@ -828,7 +838,7 @@ Helmholtz::FillEos(Real &rho, Real &temp, Real &energy, Real &press, Real &cv, R temp = math_utils::pow10(lT); } else { lT = std::log10(temp); - IndexerUtils::Get(lambda, Lambda::lT) = lT; + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); } Real p[NDERIV], e[NDERIV], s[NDERIV], etaele[NDERIV], nep[NDERIV]; GetFromDensityLogTemperature_(rho, temp, abar, zbar, ye, ytot, ywot, De, lDe, p, e, s, @@ -870,7 +880,8 @@ PORTABLE_INLINE_FUNCTION Real Helmholtz::lTFromRhoSie_(const Real rho, const Rea if (options_.ENABLE_RAD || options_.GAS_DEGENERATE || options_.ENABLE_COULOMB_CORRECTIONS) { - Real lTguess = IndexerUtils::Get(lambda, Lambda::lT); + Real lTguess; + IndexerUtils::SafeGet(lambda, Lambda::lT, lTguess); if (!((electrons_.lTMin() <= lTguess) && (lTguess <= electrons_.lTMax()))) { lTguess = lTAnalytic_(rho, e, ni, ne); if (!((electrons_.lTMin() <= lTguess) && (lTguess <= electrons_.lTMax()))) { @@ -944,7 +955,7 @@ PORTABLE_INLINE_FUNCTION Real Helmholtz::lTFromRhoSie_(const Real rho, const Rea } lT = electrons_.lTMax(); } - IndexerUtils::Get(lambda, Lambda::lT) = lT; + IndexerUtils::SafeGet(lambda, Lambda::lT, lT); return lT; } diff --git a/singularity-eos/eos/eos_spiner_rho_sie.hpp b/singularity-eos/eos/eos_spiner_rho_sie.hpp index dd54c5b22e1..8c0a035b042 100644 --- a/singularity-eos/eos/eos_spiner_rho_sie.hpp +++ b/singularity-eos/eos/eos_spiner_rho_sie.hpp @@ -687,9 +687,7 @@ PORTABLE_INLINE_FUNCTION void SpinerEOSDependsRhoSieTransformable: } } else { lRho = to_log(rho, lRhoOffset_); - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); } if (output & thermalqs::temperature) { @@ -747,9 +745,7 @@ SpinerEOSDependsRhoSieTransformable::interpRhoT_(const Real rho, Indexer_t &&lambda) const { const Real lRho = spiner_common::to_log(rho, lRhoOffset_); const Real lT = spiner_common::to_log(T, lTOffset_); - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); return db.interpToReal(lRho, lT); } @@ -761,9 +757,7 @@ SpinerEOSDependsRhoSieTransformable::interpRhoSie_( const Real lRho = spiner_common::to_log(rho, lRhoOffset_); const Real sie_transformed = transformer_.transform(sie, rho); const Real lE = spiner_common::to_log(sie_transformed, lEOffset_); - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); return db.interpToReal(lRho, lE); } @@ -786,12 +780,10 @@ SpinerEOSDependsRhoSieTransformable::lRhoFromPlT_( } } else { Real lRhoGuess = reproducible_ ? lRhoMin_ : 0.5 * (lRhoMin_ + lRhoMax_); - if (!variadic_utils::is_nullptr(lambda)) { - Real lRho_cache = - IndexerUtils::Get(lambda, Lambda::lRho); - if ((lRhoMin_ <= lRho_cache) && (lRho_cache <= lRhoMax_)) { - lRhoGuess = lRho_cache; - } + Real lRho_cache; + IndexerUtils::SafeGet(lambda, Lambda::lRho, lRho_cache); + if ((lRhoMin_ <= lRho_cache) && (lRho_cache <= lRhoMax_)) { + lRhoGuess = lRho_cache; } const callable_interp::l_interp PFunc(dependsRhoT_.P, lT); status = SP_ROOT_FINDER(PFunc, P, lRhoGuess, lRhoMin_, lRhoMax_, robust::EPS(), @@ -809,12 +801,8 @@ SpinerEOSDependsRhoSieTransformable::lRhoFromPlT_( lRho = reproducible_ ? lRhoMin_ : lRhoGuess; } } - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - if constexpr (variadic_utils::is_indexable_v) { - lambda[IndexableTypes::RootStatus()] = static_cast(status); - } - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); + IndexerUtils::SafeSet(lambda, static_cast(status)); return lRho; } diff --git a/singularity-eos/eos/eos_spiner_rho_temp.hpp b/singularity-eos/eos/eos_spiner_rho_temp.hpp index d6bf1b9d408..3015fbd7a55 100644 --- a/singularity-eos/eos/eos_spiner_rho_temp.hpp +++ b/singularity-eos/eos/eos_spiner_rho_temp.hpp @@ -880,10 +880,8 @@ SpinerEOSDependsRhoT::FillEos(Real &rho, Real &temp, Real &energy, Real &press, if (output & thermalqs::bulk_modulus) { bmod = bModFromRholRhoTlT_(rho, lRho, temp, lT, whereAmI); } - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - IndexerUtils::Get(lambda, Lambda::lT) = lT; - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); } template @@ -914,10 +912,8 @@ SpinerEOSDependsRhoT::getLogsRhoT_(const Real rho, const Real temperature, Real Real &lT, Indexer_t &&lambda) const { lRho = lRho_(rho); lT = lT_(temperature); - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - IndexerUtils::Get(lambda, Lambda::lT) = lT; - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); } template @@ -931,11 +927,10 @@ PORTABLE_INLINE_FUNCTION Real SpinerEOSDependsRhoT::lRhoFromPlT_( const RootFinding1D::RootCounts *pcounts = (memoryStatus_ == DataStatus::OnDevice) ? nullptr : &counts; - if (!variadic_utils::is_nullptr(lambda)) { - Real lRho_cache = IndexerUtils::Get(lambda, Lambda::lRho); - if ((lRhoMin_ <= lRho_cache) && (lRho_cache <= lRhoMax_)) { - lRhoGuess = lRho_cache; - } + Real lRho_cache; + IndexerUtils::SafeGet(lambda, Lambda::lRho, lRho_cache); + if ((lRhoMin_ <= lRho_cache) && (lRho_cache <= lRhoMax_)) { + lRhoGuess = lRho_cache; } if (lT <= lTMin_) { // cold curve @@ -972,17 +967,11 @@ PORTABLE_INLINE_FUNCTION Real SpinerEOSDependsRhoT::lRhoFromPlT_( #endif // SPINER_EOS_VERBOSE lRho = reproducible_ ? lRhoMax_ : lRhoGuess; } - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - IndexerUtils::Get(lambda, Lambda::lT) = lT; - if constexpr (variadic_utils::is_indexable_v) { - lambda[IndexableTypes::RootStatus()] = static_cast(status); - } - if constexpr (variadic_utils::is_indexable_v) { - lambda[IndexableTypes::TableStatus()] = static_cast(whereAmI); - } - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); + // No numerical index: only set if type-based indexing is used + IndexerUtils::SafeSet(lambda, static_cast(status)); + IndexerUtils::SafeSet(lambda, static_cast(whereAmI)); return lRho; } @@ -1013,12 +1002,10 @@ PORTABLE_INLINE_FUNCTION Real SpinerEOSDependsRhoT::lTFromlRhoSie_( } } else { Real lTGuess = reproducible_ ? lTMin_ : 0.5 * (lTMin_ + lTMax_); - if (!variadic_utils::is_nullptr(lambda)) { - Real lT_cache = - IndexerUtils::Get(lambda, Lambda::lT); - if ((lTMin_ <= lT_cache) && (lT_cache <= lTMax_)) { - lTGuess = lT_cache; - } + Real lT_cache; + IndexerUtils::SafeGet(lambda, Lambda::lT, lT_cache); + if ((lTMin_ <= lT_cache) && (lT_cache <= lTMax_)) { + lTGuess = lT_cache; } const callable_interp::r_interp sieFunc(sie_, lRho); status = SP_ROOT_FINDER(sieFunc, sie, lTGuess, lTMin_, lTMax_, ROOT_THRESH, @@ -1040,17 +1027,11 @@ PORTABLE_INLINE_FUNCTION Real SpinerEOSDependsRhoT::lTFromlRhoSie_( lT = reproducible_ ? lTMin_ : lTGuess; } } - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - IndexerUtils::Get(lambda, Lambda::lT) = lT; - if constexpr (variadic_utils::is_indexable_v) { - lambda[IndexableTypes::RootStatus()] = static_cast(status); - } - if constexpr (variadic_utils::is_indexable_v) { - lambda[IndexableTypes::TableStatus()] = static_cast(whereAmI); - } - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); + // No numerical index: only set if type-based indexing is used + IndexerUtils::SafeSet(lambda, static_cast(status)); + IndexerUtils::SafeSet(lambda, static_cast(whereAmI)); return lT; } @@ -1080,11 +1061,10 @@ PORTABLE_INLINE_FUNCTION Real SpinerEOSDependsRhoT::lTFromlRhoP_( } else { whereAmI = TableStatus::OnTable; lTGuess = 0.5 * (lTMin_ + lTMax_); - if (!variadic_utils::is_nullptr(lambda)) { - Real lT_cache = IndexerUtils::Get(lambda, lT); - if ((lTMin_ <= lT_cache) && (lT_cache <= lTMax_)) { - lTGuess = lT_cache; - } + Real lT_cache; + IndexerUtils::SafeGet(lambda, lT, lT_cache); + if ((lTMin_ <= lT_cache) && (lT_cache <= lTMax_)) { + lTGuess = lT_cache; } const callable_interp::r_interp PFunc(P_, lRho); status = SP_ROOT_FINDER(PFunc, press, lTGuess, lTMin_, lTMax_, ROOT_THRESH, @@ -1102,17 +1082,11 @@ PORTABLE_INLINE_FUNCTION Real SpinerEOSDependsRhoT::lTFromlRhoP_( lT = reproducible_ ? lTMin_ : lTGuess; } } - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, Lambda::lRho) = lRho; - IndexerUtils::Get(lambda, Lambda::lT) = lT; - if constexpr (variadic_utils::is_indexable_v) { - lambda[IndexableTypes::RootStatus()] = static_cast(status); - } - if constexpr (variadic_utils::is_indexable_v) { - lambda[IndexableTypes::TableStatus()] = static_cast(whereAmI); - } - } + IndexerUtils::SafeSet(lambda, Lambda::lRho, lRho); + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); + // No numerical index: only set if type-based indexing is used + IndexerUtils::SafeSet(lambda, static_cast(status)); + IndexerUtils::SafeSet(lambda, static_cast(whereAmI)); return lT; } diff --git a/singularity-eos/eos/eos_stellar_collapse.hpp b/singularity-eos/eos/eos_stellar_collapse.hpp index 6f98e252741..eaa79e03b88 100644 --- a/singularity-eos/eos/eos_stellar_collapse.hpp +++ b/singularity-eos/eos/eos_stellar_collapse.hpp @@ -372,8 +372,8 @@ class StellarCollapse : public EosBase { checkLambda_(lambda); lRho = lRho_(rho); lT = lT_(temp); - Ye = IndexerUtils::Get(lambda, Lambda::Ye); - IndexerUtils::Get(lambda, Lambda::lT) = lT; + Ye = IndexerUtils::SafeMustGet(lambda, Lambda::Ye); + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); } template PORTABLE_INLINE_FUNCTION __attribute__((always_inline)) void @@ -381,7 +381,7 @@ class StellarCollapse : public EosBase { Real &lT, Real &Ye) const noexcept { lRho = lRho_(rho); lT = lTFromlRhoSie_(lRho, sie, lambda); - Ye = IndexerUtils::Get(lambda, Lambda::Ye); + Ye = IndexerUtils::SafeMustGet(lambda, Lambda::Ye); return; } @@ -577,7 +577,8 @@ template PORTABLE_INLINE_FUNCTION Real StellarCollapse::MinInternalEnergyFromDensity(const Real rho, Indexer_t &&lambda) const { Real lRho = lRho_(rho); - Real Ye = IndexerUtils::Get(lambda, Lambda::Ye); + Real Ye; + Ye = IndexerUtils::SafeMustGet(lambda, Lambda::Ye); return eCold_.interpToReal(Ye, lRho); } @@ -667,7 +668,8 @@ PORTABLE_INLINE_FUNCTION void StellarCollapse::DensityEnergyFromPressureTemperat Real lrguess = lRho_(rho); Real lT = lT_(temp); Real lP = P2lP_(press); - Real Ye = IndexerUtils::Get(lambda, Lambda::Ye); + Real Ye = + IndexerUtils::SafeMustGet(lambda, Lambda::Ye); if ((lrguess < lRhoMin_) || (lrguess > lRhoMax_)) { lrguess = lRho_(rhoNormal_); @@ -680,7 +682,7 @@ PORTABLE_INLINE_FUNCTION void StellarCollapse::DensityEnergyFromPressureTemperat Real lE = lE_.interpToReal(Ye, lT, lrguess); rho = rho_(lrguess); sie = le2e_(lE); - IndexerUtils::Get(lambda, Lambda::lT) = lT; + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); } template @@ -762,8 +764,9 @@ StellarCollapse::ValuesAtReferenceState(Real &rho, Real &temp, Real &sie, Real & dpde = dPdENormal_; dvdt = dVdTNormal_; Real lT = lT_(temp); - IndexerUtils::Get(lambda, Lambda::Ye) = YeNormal_; - IndexerUtils::Get(lambda, Lambda::lT) = lT; + IndexerUtils::SafeMustSet(lambda, Lambda::Ye, + YeNormal_); + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); } inline void StellarCollapse::LoadFromSP5File_(const std::string &filename) { @@ -1177,8 +1180,10 @@ PORTABLE_INLINE_FUNCTION Real StellarCollapse::lTFromlRhoSie_( RootFinding1D::Status status = RootFinding1D::Status::SUCCESS; using RootFinding1D::regula_falsi; Real lT; - Real Ye = IndexerUtils::Get(lambda, Lambda::Ye); - Real lTGuess = IndexerUtils::Get(lambda, Lambda::lT); + Real Ye = + IndexerUtils::SafeMustGet(lambda, Lambda::Ye); + Real lTGuess; + IndexerUtils::SafeGet(lambda, Lambda::lT, lTGuess); const RootFinding1D::RootCounts *pcounts = (memoryStatus_ == DataStatus::OnDevice) ? nullptr : &counts; @@ -1224,7 +1229,7 @@ PORTABLE_INLINE_FUNCTION Real StellarCollapse::lTFromlRhoSie_( status_ = status; } #endif // PORTABILITY_STRATEGY_NONE - IndexerUtils::Get(lambda, Lambda::lT) = lT; + IndexerUtils::SafeSet(lambda, Lambda::lT, lT); return lT; } } // namespace singularity diff --git a/singularity-eos/eos/modifiers/zsplit_eos.hpp b/singularity-eos/eos/modifiers/zsplit_eos.hpp index 6192f17c63a..ce850acd4e4 100644 --- a/singularity-eos/eos/modifiers/zsplit_eos.hpp +++ b/singularity-eos/eos/modifiers/zsplit_eos.hpp @@ -275,14 +275,8 @@ class ZSplit : public EosBase> { template PORTABLE_FORCEINLINE_FUNCTION Real GetIonizationState_(Indexer_t &&lambda) const { using namespace variadic_utils; - if (is_nullptr(lambda)) { - PORTABLE_THROW_OR_ABORT("ZSplitEOS: lambda must contain mean ionization state!\n"); - } - if constexpr (is_indexable_v) { - return std::max(0.0, lambda[IndexableTypes::MeanIonizationState()]); - } else { - return std::max(0.0, lambda[T::nlambda()]); - } + return std::max(0.0, IndexerUtils::SafeMustGet( + lambda, T::nlambda())); } // TODO(JMM): Runtime? template diff --git a/test/test_indexable_types.cpp b/test/test_indexable_types.cpp index d4b2231e731..9b83e251688 100644 --- a/test/test_indexable_types.cpp +++ b/test/test_indexable_types.cpp @@ -21,6 +21,9 @@ #define CATCH_CONFIG_FAST_COMPILE #include #endif +#include + +#include using namespace singularity::IndexerUtils; using namespace singularity::IndexableTypes; @@ -29,14 +32,22 @@ using Lambda_t = VariadicIndexer data_; + std::array data_; +}; + +class NewManualLambda_t : public ManualLambda_t { + public: + // Enable recognition that this is type-indexable + constexpr static bool is_type_indexable = true; }; SCENARIO("IndexableTypes and VariadicIndexer", "[IndexableTypes][VariadicIndexer]") { @@ -63,14 +74,120 @@ SCENARIO("IndexableTypes and VariadicIndexer", "[IndexableTypes][VariadicIndexer REQUIRE(lRho == static_cast(2)); } } - WHEN("We use the Get functionality") { - // Request a type that exists, but an incorrect index - Real Zbar = Get(lambda, 2); - // Request a type that doesn't exist but an index that does - Real lRho = Get(lambda, 2); + WHEN("We use the SafeMustGet functionality") { + // Request a type that exists, but no integer index + const Real Zbar = SafeMustGet(lambda); + // Request a type that exists, but the wrong integer index + Real Zbar_1 = SafeMustGet(lambda, 2); THEN("We get the correct values") { REQUIRE(Zbar == static_cast(0)); - REQUIRE(lRho == static_cast(2)); + REQUIRE_THAT(Zbar, Catch::Matchers::WithinRel(Zbar_1, 1.0e-14)); + } + // This is probably a dumb check since Zbar_1 wasn't a Real* or a Real& + AND_THEN("We don't modify the lambda values by changing the local values") { + Zbar_1 = 5.3; + REQUIRE_THAT(Zbar, Catch::Matchers::WithinRel( + SafeMustGet(lambda), 1.0e-14)); + } + } + WHEN("We use the SafeMustSet functionality") { + // Request a type that exists, but no index + Real Zbar = 5.567; + SafeMustSet(lambda, Zbar); + // Request a type that exists, but the wrong integer index + Real lRho = 1.102; + SafeMustSet(lambda, 0, lRho); + THEN("We get the correct values") { + REQUIRE_THAT(Zbar, Catch::Matchers::WithinRel( + SafeMustGet(lambda), 1.0e-14)); + REQUIRE_THAT( + lRho, Catch::Matchers::WithinRel(SafeMustGet(lambda), 1.0e-14)); + } + } + WHEN("We use the SafeGet functionality") { + constexpr Real unmodified = -1.0; + Real destination = unmodified; + WHEN("We request a type that exists") { + const bool modified = SafeGet(lambda, 2, destination); + THEN("The destination value will be modified") { + CHECK(modified); + REQUIRE(destination == lambda[MeanIonizationState{}]); + } + } + WHEN("We request a type that doesn't exist") { + const bool modified = SafeGet(lambda, 2, destination); + THEN("The destination value will remain UNmodified") { + CHECK(!modified); + REQUIRE(destination == unmodified); + } + } + WHEN("A normal array-like lambda is used") { + std::array lambda_arr{1, 2, 3}; + constexpr size_t my_index = 2; + const bool modified = SafeGet(lambda_arr, my_index, destination); + THEN("The destination value will reflect the index from the array") { + CHECK(modified); + REQUIRE(destination == lambda_arr[my_index]); + } + } + } + WHEN("We use the SafeSet functionality") { + constexpr Real new_value = -1.0; + WHEN("We want to set a value for a type index that exists") { + const bool modified = SafeSet(lambda, 2, new_value); + THEN("The lambda index was modified") { + CHECK(modified); + REQUIRE(lambda[MeanIonizationState{}] == new_value); + } + } + WHEN("We want to set a value for a type index that doesn't exist") { + const bool modified = SafeSet(lambda, 2, new_value); + Lambda_t old_lambda; + for (std::size_t i = 0; i < Lambda_t::size(); ++i) { + old_lambda[i] = lambda[i]; + } + THEN("None of the lambda values were modified") { + CHECK(!modified); + for (std::size_t i = 0; i < Lambda_t::size(); ++i) { + INFO("i: " << i); + CHECK(lambda[i] == old_lambda[i]); + } + } + } + WHEN("A normal array-like lambda is used") { + std::array lambda_arr{4, 5, 6}; + constexpr size_t my_index = 1; + const bool modified = SafeSet(lambda_arr, my_index, new_value); + THEN("The lambda value at the appropriate index has been modified") { + CHECK(modified); + REQUIRE(lambda_arr[my_index] == new_value); + } + } + } + } + GIVEN("A normal array that does not support IndexableTypes") { + constexpr size_t num_lambda = 4; + std::array lambda{}; + for (size_t i = 0; i < num_lambda; i++) { + lambda[i] = static_cast(i); + } + WHEN("We use the SafeMustGet functionality") { + THEN("The type-based index is ignored and only the integer index is used") { + for (size_t i = 0; i < num_lambda; i++) { + INFO("i: " << i); + const Real val = SafeMustGet(lambda, i); + CHECK_THAT(lambda[i], Catch::Matchers::WithinRel(val, 1.0e-14)); + } + } + } + WHEN("We use the SafeMustSet functionality") { + THEN("The type-based index is ignored and only the integer index is used") { + for (size_t i = 0; i < num_lambda; i++) { + INFO("i: " << i); + const Real val = i * i; + SafeMustSet(lambda, i, val); + CHECK_THAT(lambda[i], Catch::Matchers::WithinRel(val, 1.0e-14)); + } } } } @@ -79,17 +196,66 @@ SCENARIO("IndexableTypes and VariadicIndexer", "[IndexableTypes][VariadicIndexer SCENARIO("IndexableTypes and ManualLambda", "[IndexableTypes]") { GIVEN("A manually written indexer, filled with indices 0, 1, 2") { ManualLambda_t lambda; - for (std::size_t i = 0; i < 3; ++i) { + for (std::size_t i = 0; i < lambda.length; ++i) { lambda[i] = static_cast(i); } - WHEN("We use the Get functionality") { - // Request a type that exists, but an incorrect index - Real Zbar = Get(lambda, 2); - // Request a type that doesn't exist but an index that does - Real lRho = Get(lambda, 2); - THEN("We get the correct values") { - REQUIRE(Zbar == static_cast(0)); - REQUIRE(lRho == static_cast(2)); + WHEN("We use the SafeGet functionality") { + constexpr Real unmodified = -1.0; + Real destination = unmodified; + WHEN("We request a type that doesn't exist in the manual indexer") { + constexpr size_t my_index = 1; + const bool modified = SafeGet(lambda, my_index, destination); + THEN("The destination WILL be modified since the manual indexer doesn't have the " + " `is_type_indexable` data member") { + CHECK(modified); + REQUIRE(destination == lambda[my_index]); + } + } + WHEN("We define a new manual indexer that has the `is_type_indexable` " + "data member") { + NewManualLambda_t lambda_new; + for (std::size_t i = 0; i < lambda.length; ++i) { + lambda_new[i] = lambda[i]; + } + WHEN("We request a type that doesn't exist in the manual indexer") { + destination = unmodified; + constexpr size_t my_index = 1; + const bool modified = SafeGet(lambda_new, my_index, destination); + THEN("The destination will NOT be modified") { + CHECK(!modified); + REQUIRE(destination == unmodified); + } + } + } + } + WHEN("We use the SafeSet functionality") { + constexpr Real new_value = -1.0; + WHEN("We want to set a value for a type index that doesn't exist") { + constexpr size_t my_index = 1; + const bool modified = SafeSet(lambda, my_index, new_value); + THEN("The lambda value WILL be modified since the manual indexer doesn't have " + "the `is_type_indexable` data member") { + CHECK(modified); + REQUIRE(lambda[my_index] == new_value); + } + } + WHEN("We define a new manual indexer that has the `is_type_indexable` data " + "member") { + NewManualLambda_t lambda_new; + for (std::size_t i = 0; i < lambda.length; ++i) { + lambda_new[i] = lambda[i]; + } + WHEN("We request a type that doesn't exist in the manual indexer") { + constexpr size_t my_index = 1; + const bool modified = SafeSet(lambda_new, my_index, new_value); + THEN("The lambda will NOT be modified") { + CHECK(!modified); + for (std::size_t i = 0; i < lambda.length; ++i) { + INFO("i: " << i); + CHECK(lambda_new[i] == lambda[i]); + } + } + } } } } diff --git a/test/test_pte_ideal.cpp b/test/test_pte_ideal.cpp index 31c91bfaa6d..f318472b251 100644 --- a/test/test_pte_ideal.cpp +++ b/test/test_pte_ideal.cpp @@ -51,8 +51,12 @@ using singularity::IndexableTypes::MeanIonizationState; struct LambdaIndexerSingle { PORTABLE_FORCEINLINE_FUNCTION Real &operator[](const int i) { return z; } + PORTABLE_INLINE_FUNCTION + const Real &operator[](const int i) const { return z; } PORTABLE_FORCEINLINE_FUNCTION Real &operator[](const MeanIonizationState &s) { return z; } + PORTABLE_INLINE_FUNCTION + const Real &operator[](const MeanIonizationState &s) const { return z; } Real z = 0.9; }; diff --git a/test/test_spiner_transform.cpp b/test/test_spiner_transform.cpp index 960f608114d..0baeedaddf6 100644 --- a/test/test_spiner_transform.cpp +++ b/test/test_spiner_transform.cpp @@ -102,9 +102,7 @@ struct TestDataContainer { const Real lRho = spiner_common::to_log(rho, lRhoOffset); const Real lE = spiner_common::to_log(sie, lEOffset); // If we wanted to mock the lambda write-back (optional) - if (!variadic_utils::is_nullptr(lambda)) { - IndexerUtils::Get(lambda, 0) = lRho; - } + IndexerUtils::SafeSet(lambda, 0, lRho); return db.interpToReal(lRho, lE); }