Skip to content

Commit 95793ae

Browse files
committed
add ScalarCache to generate different Dense scalar with different value type
1 parent a4b02e2 commit 95793ae

File tree

5 files changed

+240
-21
lines changed

5 files changed

+240
-21
lines changed

core/base/dense_cache.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44

55
#include "ginkgo/core/base/dense_cache.hpp"
66

7+
#include <memory>
8+
#include <string>
9+
10+
#include <ginkgo/core/base/array.hpp>
11+
#include <ginkgo/core/base/dim.hpp>
12+
#include <ginkgo/core/base/executor.hpp>
713
#include <ginkgo/core/matrix/dense.hpp>
814

915

@@ -70,6 +76,55 @@ std::shared_ptr<matrix::Dense<ValueType>> GenericDenseCache::get(
7076
}
7177

7278

79+
/**
 * Creates a cache that holds no scalar yet.
 *
 * @param executor  the executor passed to Dense::create when get() builds a
 *                  scalar
 * @param scalar_value  the value every scalar is filled with, after being
 *                      cast to the requested value type
 */
ScalarCache::ScalarCache(std::shared_ptr<const Executor> executor,
                         double scalar_value)
    : exec(std::move(executor)), value(scalar_value)
{}
82+
83+
/**
 * Copy-constructs a cache from the executor and value of `other`.
 *
 * The scalars cached in `other` are deliberately not copied, so the new
 * object starts with an empty map.
 *
 * @param other  the cache providing executor and value
 */
ScalarCache::ScalarCache(const ScalarCache& other)
    : exec(other.exec), value(other.value)
{}
84+
85+
86+
/**
 * Move-constructs a cache by taking over the executor and value of `other`.
 *
 * Afterwards `other` has a null executor, the value 0.0 and an empty map.
 * The scalars cached in `other` are dropped rather than transferred, so the
 * new object starts with an empty map as well.
 *
 * @param other  the cache that is moved from
 */
ScalarCache::ScalarCache(ScalarCache&& other) noexcept
    : exec(std::exchange(other.exec, nullptr)),
      value(std::exchange(other.value, 0.0))
{
    other.scalars.clear();
}
90+
91+
92+
/**
 * Copy-assigns the executor and value of another cache.
 *
 * The scalars cached in `other` are not copied. Scalars already cached in
 * this object were built for the previous executor/value pair, so they are
 * discarded; get() recreates them on demand with the new settings.
 * Self-assignment leaves the object untouched.
 *
 * @param other  the cache providing executor and value
 *
 * @return *this
 */
ScalarCache& ScalarCache::operator=(const ScalarCache& other)
{
    if (this != &other) {
        exec = other.exec;
        value = other.value;
        // Otherwise get() would keep handing out scalars that still hold the
        // old value or live on the old executor.
        scalars.clear();
    }
    return *this;
}
98+
99+
100+
/**
 * Move-assigns the executor and value of another cache.
 *
 * Afterwards `other` has a null executor, the value 0.0 and an empty map.
 * Cached scalars are never transferred. The scalars previously cached in
 * this object belong to the old executor/value pair and are discarded too, so
 * that get() cannot return stale objects.
 *
 * @param other  the cache that is moved from
 *
 * @return *this
 */
ScalarCache& ScalarCache::operator=(ScalarCache&& other) noexcept
{
    exec = std::exchange(other.exec, nullptr);
    value = std::exchange(other.value, 0.0);
    other.scalars.clear();
    // Drop our own entries as well: they were created for the settings that
    // have just been overwritten.
    scalars.clear();
    return *this;
}
107+
108+
109+
template <typename ValueType>
110+
std::shared_ptr<const matrix::Dense<ValueType>> ScalarCache::get() const
111+
{
112+
// using typeid name as key
113+
std::string value_string = typeid(ValueType).name();
114+
auto search = scalars.find(value_string);
115+
if (search != scalars.end()) {
116+
return std::dynamic_pointer_cast<const matrix::Dense<ValueType>>(
117+
search->second);
118+
} else {
119+
auto new_scalar =
120+
share(matrix::Dense<ValueType>::create(exec, dim<2>{1, 1}));
121+
new_scalar->fill(static_cast<ValueType>(value));
122+
scalars[value_string] = new_scalar;
123+
return new_scalar;
124+
}
125+
}
126+
127+
73128
// Declaration text `struct DenseCache<_type>` that is handed to
// GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE on the following statement.
// NOTE(review): GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE is defined elsewhere;
// presumably it emits one explicit instantiation per supported value type -
// confirm against its definition.
#define GKO_DECLARE_DENSE_CACHE(_type) struct DenseCache<_type>
74129
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CACHE);
75130

@@ -78,6 +133,10 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CACHE);
78133
std::shared_ptr<const Executor>, dim<2>) const
79134
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GENERIC_DENSE_CACHE_GET);
80135

136+
// Signature of ScalarCache::get<_type>() that is handed to
// GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE below. The template is defined in this
// source file only, so the value types used by callers must be instantiated
// here.
// NOTE(review): GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE is defined elsewhere;
// presumably it emits one explicit instantiation per supported value type -
// confirm against its definition.
#define GKO_DECLARE_SCALAR_CACHE_GET(_type) \
137+
std::shared_ptr<const matrix::Dense<_type>> ScalarCache::get<_type>() const
138+
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_SCALAR_CACHE_GET);
139+
81140

82141
} // namespace detail
83142
} // namespace gko

core/distributed/matrix.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
5050
DistributedBase{comm},
5151
row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
5252
imap_{exec},
53-
one_scalar_{},
53+
one_scalar_{exec, 1.0},
5454
local_mtx_{local_matrix_template->clone(exec)},
5555
non_local_mtx_{non_local_matrix_template->clone(exec)}
5656
{
@@ -60,8 +60,6 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
6060
GKO_ASSERT(
6161
(dynamic_cast<ReadableFromMatrixData<ValueType, LocalIndexType>*>(
6262
non_local_mtx_.get())));
63-
one_scalar_.init(exec, dim<2>{1, 1});
64-
one_scalar_->fill(one<value_type>());
6563
}
6664

6765
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
@@ -72,13 +70,11 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
7270
DistributedBase{comm},
7371
row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
7472
imap_{exec},
75-
one_scalar_{},
73+
one_scalar_{exec, 1.0},
7674
non_local_mtx_(::gko::matrix::Coo<ValueType, LocalIndexType>::create(
7775
exec, dim<2>{local_linop->get_size()[0], 0}))
7876
{
7977
this->set_size(size);
80-
one_scalar_.init(exec, dim<2>{1, 1});
81-
one_scalar_->fill(one<value_type>());
8278
local_mtx_ = std::move(local_linop);
8379
}
8480

@@ -91,13 +87,11 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
9187
DistributedBase{comm},
9288
row_gatherer_(RowGatherer<LocalIndexType>::create(exec, comm)),
9389
imap_(std::move(imap)),
94-
one_scalar_{}
90+
one_scalar_{exec, 1.0}
9591
{
9692
this->set_size({imap_.get_global_size(), imap_.get_global_size()});
9793
local_mtx_ = std::move(local_linop);
9894
non_local_mtx_ = std::move(non_local_linop);
99-
one_scalar_.init(exec, dim<2>{1, 1});
100-
one_scalar_->fill(one<value_type>());
10195

10296
row_gatherer_ = RowGatherer<LocalIndexType>::create(
10397
row_gatherer_->get_executor(),
@@ -427,8 +421,9 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
427421
if (recv_ptr != recv_buffer.get()) {
428422
recv_buffer->copy_from(host_recv_buffer.get());
429423
}
430-
non_local_mtx_->apply(one_scalar_.get(), recv_buffer.get(),
431-
one_scalar_.get(), local_x);
424+
non_local_mtx_->apply(
425+
one_scalar_.template get<ValueType>().get(), recv_buffer.get(),
426+
one_scalar_.template get<x_value_type>().get(), local_x);
432427
},
433428
b, x);
434429
}
@@ -479,8 +474,9 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
479474
if (recv_ptr != recv_buffer.get()) {
480475
recv_buffer->copy_from(host_recv_buffer.get());
481476
}
482-
non_local_mtx_->apply(local_alpha.get(), recv_buffer.get(),
483-
one_scalar_.get(), local_x);
477+
non_local_mtx_->apply(
478+
local_alpha.get(), recv_buffer.get(),
479+
one_scalar_.template get<x_value_type>().get(), local_x);
484480
},
485481
b, x);
486482
}
@@ -573,7 +569,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
573569
DistributedBase{other.get_communicator()},
574570
row_gatherer_{RowGatherer<LocalIndexType>::create(
575571
other.get_executor(), other.get_communicator())},
576-
imap_(other.get_executor())
572+
imap_(other.get_executor()),
573+
one_scalar_(other.get_executor(), 1.0)
577574
{
578575
*this = other;
579576
}
@@ -587,7 +584,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
587584
DistributedBase{other.get_communicator()},
588585
row_gatherer_{RowGatherer<LocalIndexType>::create(
589586
other.get_executor(), other.get_communicator())},
590-
imap_(other.get_executor())
587+
imap_(other.get_executor()),
588+
one_scalar_(other.get_executor(), 1.0)
591589
{
592590
*this = std::move(other);
593591
}
@@ -606,8 +604,6 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(
606604
non_local_mtx_->copy_from(other.non_local_mtx_);
607605
row_gatherer_->copy_from(other.row_gatherer_);
608606
imap_ = other.imap_;
609-
one_scalar_.init(this->get_executor(), dim<2>{1, 1});
610-
one_scalar_->fill(one<value_type>());
611607
}
612608
return *this;
613609
}
@@ -626,8 +622,6 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
626622
non_local_mtx_->move_from(other.non_local_mtx_);
627623
row_gatherer_->move_from(other.row_gatherer_);
628624
imap_ = std::move(other.imap_);
629-
one_scalar_.init(this->get_executor(), dim<2>{1, 1});
630-
one_scalar_->fill(one<value_type>());
631625
}
632626
return *this;
633627
}

core/test/base/dense_cache.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,129 @@ TYPED_TEST(GenericDenseCache, WorkspaceIsNotMoveAssigned)
362362
ASSERT_EQ(cache.workspace.get_size(), 0);
363363
ASSERT_EQ(cache.workspace.get_executor(), nullptr);
364364
}
365+
366+
367+
template <typename ValueType>
368+
class ScalarCache : public ::testing::Test {
369+
protected:
370+
using value_type = ValueType;
371+
372+
ScalarCache()
373+
: ref(gko::ReferenceExecutor::create()), value(1.0), cache(ref, value)
374+
{}
375+
376+
std::shared_ptr<gko::ReferenceExecutor> ref;
377+
double value;
378+
gko::detail::ScalarCache cache;
379+
};
380+
381+
TYPED_TEST_SUITE(ScalarCache, gko::test::ValueTypes, TypenameNameGenerator);
382+
383+
384+
// A freshly constructed cache stores the given executor and value and does
// not hold any scalar yet.
TYPED_TEST(ScalarCache, CanInitWithExecutorAndValue)
{
    gko::detail::ScalarCache cache(this->ref, 1.0);

    // Check the object constructed in this test, not the fixture member.
    ASSERT_EQ(cache.exec, this->ref);
    ASSERT_EQ(cache.value, 1.0);
    ASSERT_EQ(cache.scalars.size(), 0);
}
394+
395+
396+
// Requesting a scalar yields a 1x1 Dense holding the cache value and adds
// exactly one entry to the map.
TYPED_TEST(ScalarCache, CanGetScalar)
{
    using vt = typename TestFixture::value_type;
    auto& fixture_cache = this->cache;

    const auto fetched = fixture_cache.template get<vt>();

    ASSERT_NE(fetched, nullptr);
    GKO_ASSERT_EQUAL_DIMENSIONS(fetched->get_size(), gko::dim<2>(1, 1));
    ASSERT_EQ(fetched->at(0, 0), static_cast<vt>(this->value));
    ASSERT_EQ(fixture_cache.scalars.size(), 1);
}
407+
408+
409+
// Requesting a second value type adds a second entry and leaves the scalar
// of the first type usable.
TYPED_TEST(ScalarCache, CanGetScalarWithDifferentType)
{
    using first_type = typename TestFixture::value_type;
    using second_type = gko::next_precision<first_type>;
    auto& fixture_cache = this->cache;

    const auto first = fixture_cache.template get<first_type>();
    const auto second = fixture_cache.template get<second_type>();

    // the scalar obtained first is still valid
    ASSERT_NE(first, nullptr);
    GKO_ASSERT_EQUAL_DIMENSIONS(first->get_size(), gko::dim<2>(1, 1));
    ASSERT_EQ(first->at(0, 0), static_cast<first_type>(this->value));
    ASSERT_NE(second, nullptr);
    GKO_ASSERT_EQUAL_DIMENSIONS(second->get_size(), gko::dim<2>(1, 1));
    ASSERT_EQ(second->at(0, 0), static_cast<second_type>(this->value));
    // one entry per requested value type
    ASSERT_EQ(fixture_cache.scalars.size(), 2);
}
427+
428+
429+
// Copy construction takes over executor and value but none of the cached
// scalars; the source keeps its entry.
TYPED_TEST(ScalarCache, VectorIsNotCopied)
{
    using vt = typename TestFixture::value_type;
    auto& source = this->cache;
    // populate the source so that there is an entry that could be copied
    auto populated = source.template get<vt>();

    gko::detail::ScalarCache duplicate(source);

    ASSERT_EQ(duplicate.scalars.size(), 0);
    ASSERT_EQ(duplicate.value, source.value);
    ASSERT_EQ(duplicate.exec, source.exec);
    ASSERT_EQ(source.scalars.size(), 1);
}
441+
442+
443+
// Move construction takes over executor and value but none of the cached
// scalars; the source is reset completely.
TYPED_TEST(ScalarCache, VectorIsNotMoved)
{
    using vt = typename TestFixture::value_type;
    auto& source = this->cache;
    // populate the source so that there is an entry that could be moved
    auto populated = source.template get<vt>();

    gko::detail::ScalarCache target(std::move(source));

    ASSERT_EQ(target.scalars.size(), 0);
    ASSERT_EQ(target.value, this->value);
    ASSERT_EQ(target.exec, this->ref);
    // the moved-from cache is cleared
    ASSERT_EQ(source.value, 0.0);
    ASSERT_EQ(source.exec, nullptr);
    ASSERT_EQ(source.scalars.size(), 0);
}
458+
459+
460+
// Copy assignment takes over executor and value but none of the cached
// scalars; the source keeps its entry.
TYPED_TEST(ScalarCache, VectorIsNotCopyAssigned)
{
    using vt = typename TestFixture::value_type;
    auto& source = this->cache;
    // populate the source so that there is an entry that could be copied
    auto populated = source.template get<vt>();
    gko::detail::ScalarCache target(this->ref, 2.0);

    target = source;

    ASSERT_EQ(target.scalars.size(), 0);
    ASSERT_EQ(target.value, source.value);
    ASSERT_EQ(target.exec, source.exec);
    ASSERT_EQ(source.scalars.size(), 1);
}
473+
474+
475+
// Move assignment takes over executor and value but none of the cached
// scalars; the source is reset completely.
TYPED_TEST(ScalarCache, VectorIsNotMoveAssigned)
{
    using vt = typename TestFixture::value_type;
    auto& source = this->cache;
    // populate the source so that there is an entry that could be moved
    auto populated = source.template get<vt>();
    gko::detail::ScalarCache target(this->ref, 2.0);

    target = std::move(source);

    ASSERT_EQ(target.scalars.size(), 0);
    ASSERT_EQ(target.value, this->value);
    ASSERT_EQ(target.exec, this->ref);
    // the moved-from cache is cleared
    ASSERT_EQ(source.value, 0.0);
    ASSERT_EQ(source.exec, nullptr);
    ASSERT_EQ(source.scalars.size(), 0);
}

include/ginkgo/core/base/dense_cache.hpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,22 @@
66
#define GKO_PUBLIC_CORE_BASE_DENSE_CACHE_HPP_
77

88

9+
#include <map>
910
#include <memory>
11+
#include <string>
1012

1113
#include <ginkgo/core/base/array.hpp>
1214
#include <ginkgo/core/base/dim.hpp>
13-
#include <ginkgo/core/base/executor.hpp>
1415

1516

1617
namespace gko {
18+
19+
20+
class Executor;
21+
22+
class LinOp;
23+
24+
1725
namespace matrix {
1826

1927

@@ -122,6 +130,38 @@ struct GenericDenseCache {
122130
};
123131

124132

133+
/**
134+
* Manages a map to store Dense Scalar with different value_type by a
135+
* user-specified value. The workspace is buffered and reused internally to
136+
* avoid repeated allocations. Copying an instance will only yield an empty
137+
* object since copying the cached vector would not make sense. The stored
138+
* object is always mutable, so the cache can be used in a const-context.
139+
*
140+
* @internal The struct is present to wrap cache-like buffer storage that will
141+
* not be copied when the outer object gets copied.
142+
*/
143+
struct ScalarCache {
144+
ScalarCache(std::shared_ptr<const Executor> executor, double scalar_value);
145+
~ScalarCache() = default;
146+
ScalarCache(const ScalarCache& other);
147+
ScalarCache(ScalarCache&& other) noexcept;
148+
ScalarCache& operator=(const ScalarCache& other);
149+
ScalarCache& operator=(ScalarCache&& other) noexcept;
150+
std::shared_ptr<const Executor> exec;
151+
double value;
152+
mutable std::map<std::string, std::shared_ptr<const gko::LinOp>> scalars;
153+
154+
155+
/**
156+
* Pointer access to the underlying vector with specific type.
157+
*
158+
* @return Pointer to the vector view.
159+
*/
160+
template <typename ValueType>
161+
std::shared_ptr<const matrix::Dense<ValueType>> get() const;
162+
};
163+
164+
125165
} // namespace detail
126166
} // namespace gko
127167

include/ginkgo/core/distributed/matrix.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ class Matrix
681681
private:
682682
std::shared_ptr<RowGatherer<LocalIndexType>> row_gatherer_;
683683
index_map<local_index_type, global_index_type> imap_;
684-
gko::detail::DenseCache<value_type> one_scalar_;
684+
gko::detail::ScalarCache one_scalar_;
685685
gko::detail::GenericDenseCache recv_buffer_;
686686
gko::detail::GenericDenseCache host_recv_buffer_;
687687
std::shared_ptr<LinOp> local_mtx_;

0 commit comments

Comments
 (0)