Skip to content

Enable mixed precision dispatch in distributed matrix with ScalarCache, GenericDenseCache, and GenericVectorCache #1819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 21, 2025
Merged
4 changes: 2 additions & 2 deletions benchmark/test/reference/distributed_solver.profile.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ DEBUG: end copy
DEBUG: begin copy
DEBUG: end copy
DEBUG: end copy(<typename>)
DEBUG: begin dense::fill
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because ScalarCache allocate the memory when it is first used with the type

DEBUG: end dense::fill
DEBUG: begin components::aos_to_soa
DEBUG: end components::aos_to_soa
DEBUG: begin distributed_matrix::separate_local_nonlocal
Expand Down Expand Up @@ -148,6 +146,8 @@ DEBUG: begin advanced_apply(<typename>)
DEBUG: begin csr::advanced_spmv
DEBUG: end csr::advanced_spmv
DEBUG: end advanced_apply(<typename>)
DEBUG: begin dense::fill
DEBUG: end dense::fill
DEBUG: begin advanced_apply(<typename>)
DEBUG: begin csr::advanced_spmv
DEBUG: end csr::advanced_spmv
Expand Down
4 changes: 2 additions & 2 deletions benchmark/test/reference/spmv_distributed.profile.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ DEBUG: end copy
DEBUG: begin copy
DEBUG: end copy
DEBUG: end copy(<typename>)
DEBUG: begin dense::fill
DEBUG: end dense::fill
DEBUG: begin components::aos_to_soa
DEBUG: end components::aos_to_soa
DEBUG: begin distributed_matrix::separate_local_nonlocal
Expand Down Expand Up @@ -128,6 +126,8 @@ DEBUG: begin apply(<typename>)
DEBUG: begin csr::spmv
DEBUG: end csr::spmv
DEBUG: end apply(<typename>)
DEBUG: begin dense::fill
DEBUG: end dense::fill
DEBUG: begin advanced_apply(<typename>)
DEBUG: begin csr::advanced_spmv
DEBUG: end csr::advanced_spmv
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test/reference/spmv_distributed.profile.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"comm_pattern": "stencil",
"spmv": {
"csr-csr": {
"storage": 11476,
"storage": 11452,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is from DenseCache -> ScalarCache, we allocate the memory when using it.
It is 3 * sizeof(ValueType) because we use 3 processes for mpi

"time": 1.0,
"repetitions": 1,
"completed": true
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test/reference/spmv_distributed.simple.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"comm_pattern": "stencil",
"spmv": {
"csr-csr": {
"storage": 11476,
"storage": 11452,
"max_relative_norm2": 1.0,
"time": 1.0,
"repetitions": 10,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"comm_pattern": "stencil",
"spmv": {
"csr-csr": {
"storage": 17300,
"storage": 17252,
"max_relative_norm2": 1.0,
"time": 1.0,
"repetitions": 10,
Expand Down
132 changes: 131 additions & 1 deletion core/base/dense_cache.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/base/dense_cache.hpp"

#include <memory>
#include <string>

#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/dim.hpp>
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/dense_cache_accessor.hpp"

namespace gko {
namespace detail {
Expand All @@ -32,9 +39,132 @@ void DenseCache<ValueType>::init_from(
}


const array<char>& GenericDenseCacheAccessor::get_workspace(
const GenericDenseCache& cache)
{
return cache.workspace;
}


GenericDenseCache::GenericDenseCache(const GenericDenseCache&) {}


GenericDenseCache::GenericDenseCache(GenericDenseCache&&) noexcept {}


GenericDenseCache& GenericDenseCache::operator=(const GenericDenseCache&)
{
return *this;
}


GenericDenseCache& GenericDenseCache::operator=(GenericDenseCache&&) noexcept
{
return *this;
}


template <typename ValueType>
std::shared_ptr<matrix::Dense<ValueType>> GenericDenseCache::get(
std::shared_ptr<const Executor> exec, dim<2> size) const
{
if (exec != workspace.get_executor() ||
size[0] * size[1] * sizeof(ValueType) > workspace.get_size()) {
auto new_workspace =
gko::array<char>(exec, size[0] * size[1] * sizeof(ValueType));
// We use swap here, otherwise array copy/move between different
// executor will keep the original executor.
std::swap(workspace, new_workspace);
}
return matrix::Dense<ValueType>::create(
exec, size,
make_array_view(exec, size[0] * size[1],
reinterpret_cast<ValueType*>(workspace.get_data())),
size[1]);
}


std::shared_ptr<const Executor> ScalarCacheAccessor::get_executor(
const ScalarCache& cache)
{
return cache.exec;
}


double ScalarCacheAccessor::get_value(const ScalarCache& cache)
{
return cache.value;
}


const std::map<std::string, std::shared_ptr<const gko::LinOp>>&
ScalarCacheAccessor::get_scalars(const ScalarCache& cache)
{
return cache.scalars;
}


ScalarCache::ScalarCache(std::shared_ptr<const Executor> executor,
double scalar_value)
: exec(std::move(executor)), value(scalar_value){};

ScalarCache::ScalarCache(const ScalarCache& other) { *this = other; }


ScalarCache::ScalarCache(ScalarCache&& other) noexcept
{
*this = std::move(other);
}


ScalarCache& ScalarCache::operator=(const ScalarCache& other)
{
exec = other.exec;
value = other.value;
return *this;
}


ScalarCache& ScalarCache::operator=(ScalarCache&& other) noexcept
{
exec = std::exchange(other.exec, nullptr);
value = std::exchange(other.value, 0.0);
other.scalars.clear();
return *this;
}


template <typename ValueType>
std::shared_ptr<const matrix::Dense<ValueType>> ScalarCache::get() const
{
// using typeid name as key
std::string value_string = typeid(ValueType).name();
auto search = scalars.find(value_string);
if (search != scalars.end()) {
return std::dynamic_pointer_cast<const matrix::Dense<ValueType>>(
search->second);
} else {
auto new_scalar =
share(matrix::Dense<ValueType>::create(exec, dim<2>{1, 1}));
new_scalar->fill(static_cast<ValueType>(value));
scalars[value_string] = new_scalar;
return new_scalar;
}
}


#define GKO_DECLARE_DENSE_CACHE(_type) struct DenseCache<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CACHE);

#define GKO_DECLARE_GENERIC_DENSE_CACHE_GET(_type) \
std::shared_ptr<matrix::Dense<_type>> GenericDenseCache::get<_type>( \
std::shared_ptr<const Executor>, dim<2>) const
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GENERIC_DENSE_CACHE_GET);

#define GKO_DECLARE_SCALAR_CACHE_GET(_type) \
std::shared_ptr<const matrix::Dense<_type>> ScalarCache::get<_type>() const
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_SCALAR_CACHE_GET);


} // namespace detail
} // namespace gko
50 changes: 50 additions & 0 deletions core/base/dense_cache_accessor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_BASE_DENSE_CACHE_ACCESSOR_HPP_
#define GKO_CORE_BASE_DENSE_CACHE_ACCESSOR_HPP_


#include <map>
#include <string>

#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/lin_op.hpp>


namespace gko {
namespace detail {


// helper to access private member for testing
class GenericDenseCacheAccessor {
public:
// access to the workspace
static const array<char>& get_workspace(const GenericDenseCache& cache);
};


// helper to access private member for testing
class ScalarCacheAccessor {
public:
// access to the executor
static std::shared_ptr<const Executor> get_executor(
const ScalarCache& cache);

// access to the value
static double get_value(const ScalarCache& cache);

// access to the scalars
static const std::map<std::string, std::shared_ptr<const gko::LinOp>>&
get_scalars(const ScalarCache& cache);
};


} // namespace detail
} // namespace gko


#endif // GKO_CORE_BASE_DENSE_CACHE_ACCESSOR_HPP_
Loading
Loading