Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5a492fd
Add ability to check if a type accepts an whole number index
jhp-lanl Sep 12, 2025
02a7396
Add safeSet() and safeGet() helpers
jhp-lanl Sep 12, 2025
8895d80
Add safeGet and safeSet that don't take numerical indices
jhp-lanl Sep 12, 2025
58a920b
Switch to safeGet and safeSet for indexable types
jhp-lanl Sep 12, 2025
fce8f7c
Whoops... forgot comment
jhp-lanl Sep 12, 2025
4959500
Clang format
jhp-lanl Sep 12, 2025
c2b1dbb
Update changelog
jhp-lanl Sep 12, 2025
589e1dc
Update doc
jhp-lanl Sep 12, 2025
fd74c25
Make functions PORTABLE and add required get/set
jhp-lanl Sep 12, 2025
a3d71fc
Make code more DRY and rename things a bit
jhp-lanl Sep 13, 2025
856531f
Add int index check
jhp-lanl Sep 13, 2025
220df73
Rename safeGet/Set to SafeGet/Set and remove direct indexing or regular
jhp-lanl Sep 13, 2025
271e52e
Make indexer const correct
jhp-lanl Sep 13, 2025
63d5a39
Rename safeGet/Set
jhp-lanl Sep 13, 2025
8a893a7
Clang format
jhp-lanl Sep 13, 2025
d6a9f68
Switch to template-based decision to use integer index or not
jhp-lanl Sep 13, 2025
cf8bf93
Whoops... forgot to return
jhp-lanl Sep 13, 2025
2d4153a
Clang format
jhp-lanl Sep 13, 2025
0a4baa1
Add docs for SafeMustGet() and SafeMustSet()
jhp-lanl Sep 13, 2025
4c89631
Get rid of Get and have wrappers use GetSet. Also update comments and…
jhp-lanl Sep 17, 2025
9d52960
Switch Get for Safe versions and expand tests
jhp-lanl Sep 17, 2025
69b56ab
Merge branch 'main' of github.com:lanl/singularity-eos into jhp/spine…
jhp-lanl Sep 17, 2025
0feb8e4
Remove last Get in favor of SafeSet
jhp-lanl Sep 17, 2025
c245728
Remove a few more instances of Get in favor of the Safe variety for I…
jhp-lanl Sep 17, 2025
734d55b
Clang format
jhp-lanl Sep 17, 2025
efd9a9f
Whoops... void was a bad choice for a type index
jhp-lanl Sep 17, 2025
d8dbccc
Let's try an unreachable return to make decltype(auto) happy
jhp-lanl Sep 17, 2025
bfacaac
Add dependent_false_v for if constexpr static asserts
jhp-lanl Sep 17, 2025
36d4286
Move throw into SafeGet/SafeSet wrappers and provide helpful compile …
jhp-lanl Sep 17, 2025
c783623
Can't test for a runtime throw when compile-time error will be used
jhp-lanl Sep 17, 2025
232ecca
Small doc tweaks
jhp-lanl Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

### Fixed (Repair bugs, etc)
- [[PR561]](https://github.com/lanl/singularity-eos/pull/561) Fix logic for kokkos-kernels in spackage so that it is only required for closure models on GPU
- [[PR564]](https://github.com/lanl/singularity-eos/pull/564) Fix logic for numerical vs type indices by adding safeGet() and safeSet() helpers

### Changed (changing behavior/API/variables/...)

Expand Down
50 changes: 44 additions & 6 deletions doc/sphinx/src/using-eos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ of the ``[]`` operator that takes your type. For example:

class MyLambda_t {
public:
// Enable recognition that this is type-indexable
constexpr static bool is_type_indexable = true;
MyLambda_t() = default;
PORTABLE_FORCEINLINE_FUNCTION
Real &operator[](const std::size_t idx) const {
Expand All @@ -724,17 +726,53 @@ which might be used as
where ``MeanIonizationState`` is shorthand for index 2, since you
defined that overload. Note that the ``operator[]`` must be marked
``const``. To more easily enable mixing and matching integer-based
indexing with type-based indexing, the function
indexing with type-based indexing, the functions

.. code-block:: cpp

template<typename Name_t, typename Indexer_t>
template <typename T, typename Indexer_t>
PORTABLE_FORCEINLINE_FUNCTION
Real &Get(Indexer_t &&lambda, std::size_t idx = 0);
inline bool safeGet(Indexer_t const &lambda, std::size_t const idx, Real &out);

will return a reference to the value at named index ``Name_t()`` if
that overload is defined in ``Indexer_t`` and otherwise return a
reference at index ``idx``.
.. code-block:: cpp

template <typename T, typename Indexer_t>
PORTABLE_FORCEINLINE_FUNCTION
inline bool safeGet(Indexer_t const &lambda, Real &out);

will update the value of ``out`` with the value at either the appropriate
type-based index, ``T``, or the numerical index, ``idx``, if the ``Indexer_t``
doesn't accept type-based indexing. If the ``Indexer_t`` **does** accept
type-based indexing but **doesn't** have the requested index, then the
``out`` value is not updated. The same is true for when ``Indexer_t`` is the
``nullptr``. The overload that doesn't take a numerical index will *only*
return the value at a type-based index.

Similarly, the functions

.. code-block:: cpp

template <typename T, typename Indexer_t>
PORTABLE_FORCEINLINE_FUNCTION
inline bool safeSet(Indexer_t &lambda, std::size_t const idx, Real const in);

.. code-block:: cpp

template <typename T, typename Indexer_t>
PORTABLE_FORCEINLINE_FUNCTION
inline bool safeSet(Indexer_t &lambda, Real const in)

can modify the values in the ``Indexer_t`` and behave the same way. In this way,
if a type-based index isn't present in the container, then another index won't
be overwritten.

.. note::

Previous versions defined a ``Get()`` function that was "unsafe" in the
sense that it would fall back on the numerical index even if a type-based
indexer was used. This could result in retrieving and overwriting incorrect
values in the indexer. We recommend not using this function and instead using
the "safe" versions.

As a convenience tool, the struct

Expand Down
147 changes: 145 additions & 2 deletions singularity-eos/base/indexable_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,131 @@

namespace singularity {
namespace IndexerUtils {
// Convenience function for accessing an indexer by either type or
// natural number index depending on what is available

// Identifies an indexer as a type-based indexer
template <class, class = void>
struct is_type_indexer : std::false_type {};
template <class Indexer_t>
struct is_type_indexer<Indexer_t,
std::void_t<decltype(std::decay_t<Indexer_t>::is_type_indexable)>>
: std::bool_constant<std::decay_t<Indexer_t>::is_type_indexable> {};
template <class Indexer_t>
constexpr bool is_type_indexer_v = is_type_indexer<Indexer_t>::value;

// The "safe" version of Get(). This function will ONLY return a value IF that
// type-based index is present in the Indexer OR if the Indexer doesn't support
// type-based indexing.
template <typename T, typename Indexer_t>
inline bool safeGet(Indexer_t const &lambda, std::size_t const idx, Real &out) {
// If null then nothing happens
if (variadic_utils::is_nullptr(lambda)) {
return false;
}

// Return value if type index is available
if constexpr (variadic_utils::is_indexable_v<Indexer_t, T>) {
out = lambda[T{}];
return true;
}

// Do nothing if lambda has type indexing BUT doesn't have this type index
if constexpr (is_type_indexer_v<Indexer_t>) {
return false;
}

// Fall back to numerical indexing if no type indexing
if constexpr (variadic_utils::has_whole_num_index<Indexer_t>::value) {
out = lambda[idx];
return true;
}

// Something else...
return false;
}

// Overload when no index is provided
template <typename T, typename Indexer_t>
inline bool safeGet(Indexer_t const &lambda, Real &out) {
// If null then nothing happens
if (variadic_utils::is_nullptr(lambda)) {
return false;
}

// Return value if type index is available
if constexpr (variadic_utils::is_indexable_v<Indexer_t, T>) {
out = lambda[T{}];
return true;
}

// Do nothing if lambda has type indexing BUT doesn't have this type index
if constexpr (is_type_indexer_v<Indexer_t>) {
return false;
}

// Something else...
return false;
}

// Break out "Set" functionality from "Get". The original "Get()" did both, but
// the "safe" version needs to separate that functionality for setting the
// values in a lambda
template <typename T, typename Indexer_t>
inline bool safeSet(Indexer_t &lambda, std::size_t const idx, Real const in) {
// If null then nothing happens
if (variadic_utils::is_nullptr(lambda)) {
return false;
}

// Return value if type index is available
if constexpr (variadic_utils::is_indexable_v<Indexer_t, T>) {
lambda[T{}] = in;
return true;
}

// Do nothing if lambda has type indexing BUT doesn't have this type index
if constexpr (is_type_indexer_v<Indexer_t>) {
return false;
}

// Fall back to numerical indexing if no type indexing
if constexpr (variadic_utils::has_whole_num_index<Indexer_t>::value) {
lambda[idx] = in;
return true;
}

// Something else...
return false;
}

// Overload without numeric index
template <typename T, typename Indexer_t>
inline bool safeSet(Indexer_t &lambda, Real const in) {
// If null then nothing happens
if (variadic_utils::is_nullptr(lambda)) {
return false;
}

// Return value if type index is available
if constexpr (variadic_utils::is_indexable_v<Indexer_t, T>) {
lambda[T{}] = in;
return true;
}

// Do nothing if lambda has type indexing BUT doesn't have this type index
if constexpr (is_type_indexer_v<Indexer_t>) {
return false;
}

// Something else...
return false;
}

// NOTE: this Get is "unsafe" because it can allow you to overwrite a type-based
// index since it automatically falls back to numeric indexing if the type
// index isn't present.

// Convenience function for accessing an indexer by either type or natural
// number index depending on what is available.
template <typename T, typename Indexer_t>
PORTABLE_FORCEINLINE_FUNCTION auto &Get(Indexer_t &&lambda, std::size_t idx = 0) {
if constexpr (variadic_utils::is_indexable_v<Indexer_t, T>) {
Expand All @@ -40,25 +163,44 @@ PORTABLE_FORCEINLINE_FUNCTION auto &Get(Indexer_t &&lambda, std::size_t idx = 0)
template <typename Data_t, typename... Ts>
class VariadicIndexerBase {
public:
// Any class that wants to be recognized as indexable (so that we don't
// accidentally fall back to integer indexing when we don't want to) needs to
// include this.
constexpr static bool is_type_indexable = true;

// JHP: another option for the `is_type_indexable` flag is to take the ADL
// route. Essentially this would involve defining a friend function that
// could be defined in an appropriate namesapce so that theoretically a host
// code could use a TPL container with type-based indexing and allow that
// container to be flagged in our code as acceptable. This seems like a bit
// of a heavy hammer for what we need here though. We can easily change this
// if a TPL provides a type that is being used for this purpose.

VariadicIndexerBase() = default;

PORTABLE_FORCEINLINE_FUNCTION
VariadicIndexerBase(const Data_t &data) : data_(data) {}

template <typename T,
typename = std::enable_if_t<variadic_utils::contains<T, Ts...>::value>>
PORTABLE_FORCEINLINE_FUNCTION Real &operator[](const T &t) {
constexpr std::size_t idx = variadic_utils::GetIndexInTL<T, Ts...>();
return data_[idx];
}

PORTABLE_FORCEINLINE_FUNCTION
Real &operator[](const std::size_t idx) { return data_[idx]; }

template <typename T,
typename = std::enable_if_t<variadic_utils::contains<T, Ts...>::value>>
PORTABLE_FORCEINLINE_FUNCTION const Real &operator[](const T &t) const {
constexpr std::size_t idx = variadic_utils::GetIndexInTL<T, Ts...>();
return data_[idx];
}

PORTABLE_FORCEINLINE_FUNCTION
const Real &operator[](const std::size_t idx) const { return data_[idx]; }

static inline constexpr std::size_t size() { return sizeof...(Ts); }

private:
Expand All @@ -70,6 +212,7 @@ using VariadicIndexer = VariadicIndexerBase<std::array<Real, sizeof...(Ts)>, Ts.
// uses a Real*
template <typename... Ts>
using VariadicPointerIndexer = VariadicIndexerBase<Real *, Ts...>;

} // namespace IndexerUtils

namespace IndexableTypes {
Expand Down
10 changes: 10 additions & 0 deletions singularity-eos/base/variadic_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ struct is_indexable<T, Index,
template <typename T, typename Index>
constexpr bool is_indexable_v = is_indexable<T, Index>::value;

// Check if a type can accept a size_t index
template <class T, class = void>
struct has_whole_num_index : std::false_type {};
template <class T>
struct has_whole_num_index<
T, std::void_t<decltype(std::declval<T>()[std::declval<std::size_t>()])>>
: std::true_type {};
template <typename T>
constexpr bool has_whole_num_index_v = has_whole_num_index<T>::value;

// this flattens a typelist of typelists to a single typelist

// first parameter - accumulator
Expand Down
30 changes: 9 additions & 21 deletions singularity-eos/eos/eos_spiner_rho_sie.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,7 @@ PORTABLE_INLINE_FUNCTION void SpinerEOSDependsRhoSieTransformable<TransformerT>:
}
} else {
lRho = to_log(rho, lRhoOffset_);
if (!variadic_utils::is_nullptr(lambda)) {
IndexerUtils::Get<IndexableTypes::LogDensity>(lambda, Lambda::lRho) = lRho;
}
IndexerUtils::safeSet<IndexableTypes::LogDensity>(lambda, Lambda::lRho, lRho);
}
if (output & thermalqs::temperature) {

Expand Down Expand Up @@ -747,9 +745,7 @@ SpinerEOSDependsRhoSieTransformable<TransformerT>::interpRhoT_(const Real rho,
Indexer_t &&lambda) const {
const Real lRho = spiner_common::to_log(rho, lRhoOffset_);
const Real lT = spiner_common::to_log(T, lTOffset_);
if (!variadic_utils::is_nullptr(lambda)) {
IndexerUtils::Get<IndexableTypes::LogDensity>(lambda, Lambda::lRho) = lRho;
}
IndexerUtils::safeSet<IndexableTypes::LogDensity>(lambda, Lambda::lRho, lRho);
return db.interpToReal(lRho, lT);
}

Expand All @@ -761,9 +757,7 @@ SpinerEOSDependsRhoSieTransformable<TransformerT>::interpRhoSie_(
const Real lRho = spiner_common::to_log(rho, lRhoOffset_);
const Real sie_transformed = transformer_.transform(sie, rho);
const Real lE = spiner_common::to_log(sie_transformed, lEOffset_);
if (!variadic_utils::is_nullptr(lambda)) {
IndexerUtils::Get<IndexableTypes::LogDensity>(lambda, Lambda::lRho) = lRho;
}
IndexerUtils::safeSet<IndexableTypes::LogDensity>(lambda, Lambda::lRho, lRho);
return db.interpToReal(lRho, lE);
}

Expand All @@ -786,12 +780,10 @@ SpinerEOSDependsRhoSieTransformable<TransformerT>::lRhoFromPlT_(
}
} else {
Real lRhoGuess = reproducible_ ? lRhoMin_ : 0.5 * (lRhoMin_ + lRhoMax_);
if (!variadic_utils::is_nullptr(lambda)) {
Real lRho_cache =
IndexerUtils::Get<IndexableTypes::LogDensity>(lambda, Lambda::lRho);
if ((lRhoMin_ <= lRho_cache) && (lRho_cache <= lRhoMax_)) {
lRhoGuess = lRho_cache;
}
Real lRho_cache;
IndexerUtils::safeGet<IndexableTypes::LogDensity>(lambda, Lambda::lRho, lRho_cache);
if ((lRhoMin_ <= lRho_cache) && (lRho_cache <= lRhoMax_)) {
lRhoGuess = lRho_cache;
}
const callable_interp::l_interp PFunc(dependsRhoT_.P, lT);
status = SP_ROOT_FINDER(PFunc, P, lRhoGuess, lRhoMin_, lRhoMax_, robust::EPS(),
Expand All @@ -809,12 +801,8 @@ SpinerEOSDependsRhoSieTransformable<TransformerT>::lRhoFromPlT_(
lRho = reproducible_ ? lRhoMin_ : lRhoGuess;
}
}
if (!variadic_utils::is_nullptr(lambda)) {
IndexerUtils::Get<IndexableTypes::LogDensity>(lambda, Lambda::lRho) = lRho;
if constexpr (variadic_utils::is_indexable_v<Indexer_t, IndexableTypes::RootStatus>) {
lambda[IndexableTypes::RootStatus()] = static_cast<Real>(status);
}
}
IndexerUtils::safeSet<IndexableTypes::LogDensity>(lambda, Lambda::lRho, lRho);
IndexerUtils::safeSet<IndexableTypes::RootStatus>(lambda, static_cast<Real>(status));
return lRho;
}

Expand Down
Loading