Skip to content

Commit 0725db0

Browse files
committed
return the vectors in init_recv_buffers and merge init and get together
1 parent 2f899c1 commit 0725db0

File tree

4 files changed

+100
-103
lines changed

4 files changed

+100
-103
lines changed

core/distributed/matrix.cpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "ginkgo/core/distributed/matrix.hpp"
66

7+
#include <utility>
8+
79
#include <ginkgo/core/base/array.hpp>
810
#include <ginkgo/core/base/precision_dispatch.hpp>
911
#include <ginkgo/core/distributed/assembly.hpp>
@@ -416,21 +418,28 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
416418
}
417419

418420

419-
template <typename LocalIndexType>
420-
void init_recv_buffers(std::shared_ptr<const Executor> exec,
421-
const RowGatherer<LocalIndexType>* row_gatherer,
422-
size_type num_cols,
423-
const detail::GenericVectorCache& buffer,
424-
const detail::GenericVectorCache& host_buffer)
421+
template <typename ValueType, typename LocalIndexType>
422+
std::pair<std::shared_ptr<Vector<ValueType>>,
423+
std::shared_ptr<Vector<ValueType>>>
424+
init_recv_buffers(std::shared_ptr<const Executor> exec,
425+
const RowGatherer<LocalIndexType>* row_gatherer,
426+
size_type num_cols, const detail::GenericVectorCache& buffer,
427+
const detail::GenericVectorCache& host_buffer)
425428
{
429+
auto comm =
430+
row_gatherer->get_collective_communicator()->get_base_communicator();
426431
auto global_recv_dim =
427432
dim<2>{static_cast<size_type>(row_gatherer->get_size()[0]), num_cols};
428433
auto local_recv_dim = dim<2>{
429434
static_cast<size_type>(
430435
row_gatherer->get_collective_communicator()->get_recv_size()),
431436
num_cols};
432-
buffer.init(exec, global_recv_dim, local_recv_dim);
433-
host_buffer.init(exec->get_master(), global_recv_dim, local_recv_dim);
437+
438+
auto vector = buffer.template get<ValueType>(exec, comm, global_recv_dim,
439+
local_recv_dim);
440+
auto host_vector = host_buffer.template get<ValueType>(
441+
exec->get_master(), comm, global_recv_dim, local_recv_dim);
442+
return std::make_pair(vector, host_vector);
434443
}
435444

436445

@@ -455,11 +464,10 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
455464

456465
auto exec = this->get_executor();
457466
auto comm = this->get_communicator();
458-
init_recv_buffers(exec, row_gatherer_.get(), dense_b->get_size()[1],
459-
recv_buffer_, host_recv_buffer_);
460-
auto host_recv_vector =
461-
host_recv_buffer_.template get<b_value_type>(comm);
462-
auto recv_vector = recv_buffer_.template get<b_value_type>(comm);
467+
auto [recv_vector, host_recv_vector] =
468+
init_recv_buffers<b_value_type>(
469+
exec, row_gatherer_.get(), dense_b->get_size()[1],
470+
recv_buffer_, host_recv_buffer_);
463471
auto recv_ptr = mpi::requires_host_buffer(exec, comm)
464472
? host_recv_vector.get()
465473
: recv_vector.get();
@@ -503,11 +511,10 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
503511

504512
auto exec = this->get_executor();
505513
auto comm = this->get_communicator();
506-
init_recv_buffers(exec, row_gatherer_.get(), dense_b->get_size()[1],
507-
recv_buffer_, host_recv_buffer_);
508-
auto host_recv_vector =
509-
host_recv_buffer_.template get<b_value_type>(comm);
510-
auto recv_vector = recv_buffer_.template get<b_value_type>(comm);
514+
auto [recv_vector, host_recv_vector] =
515+
init_recv_buffers<b_value_type>(
516+
exec, row_gatherer_.get(), dense_b->get_size()[1],
517+
recv_buffer_, host_recv_buffer_);
511518
auto recv_ptr = mpi::requires_host_buffer(exec, comm)
512519
? host_recv_vector.get()
513520
: recv_vector.get();
@@ -552,10 +559,9 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::col_scale(
552559
make_const_array_view(exec, n_local_cols,
553560
scaling_factors_ptr->get_const_local_values()));
554561

555-
init_recv_buffers(exec, row_gatherer_.get(), scaling_factors->get_size()[1],
556-
recv_buffer_, host_recv_buffer_);
557-
auto host_recv_vector = host_recv_buffer_.template get<ValueType>(comm);
558-
auto recv_vector = recv_buffer_.template get<ValueType>(comm);
562+
auto [recv_vector, host_recv_vector] = init_recv_buffers<ValueType>(
563+
exec, row_gatherer_.get(), scaling_factors->get_size()[1], recv_buffer_,
564+
host_recv_buffer_);
559565
auto recv_ptr = mpi::requires_host_buffer(exec, comm)
560566
? host_recv_vector.get()
561567
: recv_vector.get();

core/distributed/vector_cache.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,33 +68,27 @@ GenericVectorCache& GenericVectorCache::operator=(GenericVectorCache&&) noexcept
6868
}
6969

7070

71-
void GenericVectorCache::init(std::shared_ptr<const Executor> exec,
72-
dim<2> global_size, dim<2> local_size) const
73-
{
74-
exec_ = exec;
75-
global_size_ = global_size;
76-
local_size_ = local_size;
77-
}
78-
7971
template <typename ValueType>
8072
std::shared_ptr<Vector<ValueType>> GenericVectorCache::get(
81-
gko::experimental::mpi::communicator comm) const
73+
std::shared_ptr<const Executor> exec,
74+
gko::experimental::mpi::communicator comm, dim<2> global_size,
75+
dim<2> local_size) const
8276
{
83-
auto required_size = local_size_[0] * local_size_[1] * sizeof(ValueType);
84-
if (exec_ != workspace.get_executor() ||
77+
auto required_size = local_size[0] * local_size[1] * sizeof(ValueType);
78+
if (exec != workspace.get_executor() ||
8579
required_size > workspace.get_size()) {
86-
auto new_workspace = gko::array<char>(exec_, required_size);
80+
auto new_workspace = gko::array<char>(exec, required_size);
8781
// We use swap here, otherwise array copy/move between different
8882
// executor will keep the original executor.
8983
std::swap(workspace, new_workspace);
9084
}
9185
return Vector<ValueType>::create(
92-
exec_, comm, global_size_,
86+
exec, comm, global_size,
9387
matrix::Dense<ValueType>::create(
94-
exec_, local_size_,
95-
make_array_view(exec_, local_size_[0] * local_size_[1],
88+
exec, local_size,
89+
make_array_view(exec, local_size[0] * local_size[1],
9690
reinterpret_cast<ValueType*>(workspace.get_data())),
97-
local_size_[1]));
91+
local_size[1]));
9892
}
9993

10094

@@ -110,9 +104,11 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_VECTOR_CACHE);
110104

111105
class GenericVectorCache;
112106

113-
#define GKO_DECLARE_GENERIC_VECTOR_CACHE_GET(_type) \
114-
std::shared_ptr<Vector<_type>> GenericVectorCache::get( \
115-
gko::experimental::mpi::communicator comm) const
107+
#define GKO_DECLARE_GENERIC_VECTOR_CACHE_GET(_type) \
108+
std::shared_ptr<Vector<_type>> GenericVectorCache::get( \
109+
std::shared_ptr<const Executor> exec, \
110+
gko::experimental::mpi::communicator comm, dim<2> global_size, \
111+
dim<2> local_size) const
116112

117113
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GENERIC_VECTOR_CACHE_GET);
118114

core/test/mpi/distributed/vector_cache.cpp

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,9 @@ TYPED_TEST(GenericVectorCache, GenericCanInitWithSize)
359359
{
360360
using value_type = typename TestFixture::value_type;
361361

362-
this->cache.init(this->ref, this->default_global_size,
363-
this->default_local_size);
364-
// only initialize when knowning the type
365-
auto buffer = this->cache.template get<value_type>(this->comm);
362+
auto buffer = this->cache.template get<value_type>(
363+
this->ref, this->comm, this->default_global_size,
364+
this->default_local_size);
366365

367366
ASSERT_NE(buffer, nullptr);
368367
GKO_ASSERT_EQUAL_DIMENSIONS(buffer->get_size(), this->default_global_size);
@@ -376,14 +375,16 @@ TYPED_TEST(GenericVectorCache, GenericCanInitWithSize)
376375
TYPED_TEST(GenericVectorCache, SecondInitWithSameSizeIsNoOp)
377376
{
378377
using value_type = typename TestFixture::value_type;
379-
this->cache.init(this->ref, this->default_global_size,
380-
this->default_local_size);
381-
auto buffer = this->cache.template get<value_type>(this->comm);
378+
auto buffer = this->cache.template get<value_type>(
379+
this->ref, this->comm, this->default_global_size,
380+
this->default_local_size);
382381
auto array_ptr =
383382
generic_accessor::get_workspace(this->cache).get_const_data();
384383
auto array_size = generic_accessor::get_workspace(this->cache).get_size();
385384

386-
auto second_buffer = this->cache.template get<value_type>(this->comm);
385+
auto second_buffer = this->cache.template get<value_type>(
386+
this->ref, this->comm, this->default_global_size,
387+
this->default_local_size);
387388

388389
ASSERT_NE(second_buffer, nullptr);
389390
GKO_ASSERT_EQUAL_DIMENSIONS(second_buffer->get_size(),
@@ -402,17 +403,17 @@ TYPED_TEST(GenericVectorCache, SecondInitWithSameSizeIsNoOp)
402403
TYPED_TEST(GenericVectorCache, SecondInitWithTheSmallEqSizeIsNoOp)
403404
{
404405
using value_type = typename TestFixture::value_type;
405-
gko::dim<2> second_local_size{1, 1};
406-
gko::dim<2> second_global_size{this->num_ranks, 1};
407-
this->cache.init(this->ref, this->default_global_size,
408-
this->default_local_size);
409-
auto buffer = this->cache.template get<value_type>(this->comm);
406+
gko::dim<2> second_local_size(1, 1);
407+
gko::dim<2> second_global_size(this->num_ranks, 1);
408+
auto buffer = this->cache.template get<value_type>(
409+
this->ref, this->comm, this->default_global_size,
410+
this->default_local_size);
410411
auto array_ptr =
411412
generic_accessor::get_workspace(this->cache).get_const_data();
412413
auto array_size = generic_accessor::get_workspace(this->cache).get_size();
413414

414-
this->cache.init(this->ref, second_global_size, second_local_size);
415-
auto second_buffer = this->cache.template get<value_type>(this->comm);
415+
auto second_buffer = this->cache.template get<value_type>(
416+
this->ref, this->comm, second_global_size, second_local_size);
416417

417418
ASSERT_NE(second_buffer, nullptr);
418419
GKO_ASSERT_EQUAL_DIMENSIONS(second_buffer->get_size(), second_global_size);
@@ -430,18 +431,18 @@ TYPED_TEST(GenericVectorCache, SecondInitWithTheSmallEqSizeIsNoOp)
430431
TYPED_TEST(GenericVectorCache, SecondInitWithTheLargerSizeRecreate)
431432
{
432433
using value_type = typename TestFixture::value_type;
433-
gko::dim<2> second_local_size{this->rank + 2, 3};
434-
gko::dim<2> second_global_size{this->num_ranks * (this->num_ranks + 3) / 2,
435-
3};
436-
this->cache.init(this->ref, this->default_global_size,
437-
this->default_local_size);
438-
auto buffer = this->cache.template get<value_type>(this->comm);
434+
gko::dim<2> second_local_size(this->rank + 2, 3);
435+
gko::dim<2> second_global_size(this->num_ranks * (this->num_ranks + 3) / 2,
436+
3);
437+
auto buffer = this->cache.template get<value_type>(
438+
this->ref, this->comm, this->default_global_size,
439+
this->default_local_size);
439440
auto array_ptr =
440441
generic_accessor::get_workspace(this->cache).get_const_data();
441442
auto array_size = generic_accessor::get_workspace(this->cache).get_size();
442443

443-
this->cache.init(this->ref, second_global_size, second_local_size);
444-
auto second_buffer = this->cache.template get<value_type>(this->comm);
444+
auto second_buffer = this->cache.template get<value_type>(
445+
this->ref, this->comm, second_global_size, second_local_size);
445446

446447
ASSERT_NE(second_buffer, nullptr);
447448
GKO_ASSERT_EQUAL_DIMENSIONS(second_buffer->get_size(), second_global_size);
@@ -460,14 +461,16 @@ TYPED_TEST(GenericVectorCache, GenericCanInitWithSizeAndType)
460461
{
461462
using value_type = typename TestFixture::value_type;
462463
using another_type = gko::next_precision<value_type>;
463-
this->cache.init(this->ref, this->default_global_size,
464-
this->default_local_size);
465-
auto buffer = this->cache.template get<value_type>(this->comm);
464+
auto buffer = this->cache.template get<value_type>(
465+
this->ref, this->comm, this->default_global_size,
466+
this->default_local_size);
466467
auto array_ptr =
467468
generic_accessor::get_workspace(this->cache).get_const_data();
468469
auto array_size = generic_accessor::get_workspace(this->cache).get_size();
469470

470-
auto second_buffer = this->cache.template get<another_type>(this->comm);
471+
auto second_buffer = this->cache.template get<another_type>(
472+
this->ref, this->comm, this->default_global_size,
473+
this->default_local_size);
471474

472475
ASSERT_NE(second_buffer, nullptr);
473476
GKO_ASSERT_EQUAL_DIMENSIONS(second_buffer->get_size(),
@@ -497,16 +500,16 @@ TYPED_TEST(GenericVectorCache, GenericCanInitWithDifferentExecutor)
497500
{
498501
using value_type = typename TestFixture::value_type;
499502
auto another_ref = gko::ReferenceExecutor::create();
500-
this->cache.init(this->ref, this->default_global_size,
501-
this->default_local_size);
502-
auto buffer = this->cache.template get<value_type>(this->comm);
503+
auto buffer = this->cache.template get<value_type>(
504+
this->ref, this->comm, this->default_global_size,
505+
this->default_local_size);
503506
auto array_ptr =
504507
generic_accessor::get_workspace(this->cache).get_const_data();
505508
auto array_size = generic_accessor::get_workspace(this->cache).get_size();
506509

507-
this->cache.init(another_ref, this->default_global_size,
508-
this->default_local_size);
509-
auto second_buffer = this->cache.template get<value_type>(this->comm);
510+
auto second_buffer = this->cache.template get<value_type>(
511+
another_ref, this->comm, this->default_global_size,
512+
this->default_local_size);
510513

511514
ASSERT_NE(second_buffer, nullptr);
512515
GKO_ASSERT_EQUAL_DIMENSIONS(second_buffer->get_size(),
@@ -524,9 +527,9 @@ TYPED_TEST(GenericVectorCache, GenericCanInitWithDifferentExecutor)
524527
TYPED_TEST(GenericVectorCache, WorkspaceIsNotCopied)
525528
{
526529
using value_type = typename TestFixture::value_type;
527-
this->cache.init(this->ref, this->default_global_size,
528-
this->default_local_size);
529-
auto buffer = this->cache.template get<value_type>(this->comm);
530+
auto buffer = this->cache.template get<value_type>(
531+
this->ref, this->comm, this->default_global_size,
532+
this->default_local_size);
530533

531534
gko::experimental::distributed::detail::GenericVectorCache cache(
532535
this->cache);
@@ -539,9 +542,9 @@ TYPED_TEST(GenericVectorCache, WorkspaceIsNotCopied)
539542
TYPED_TEST(GenericVectorCache, WorkspaceIsNotMoved)
540543
{
541544
using value_type = typename TestFixture::value_type;
542-
this->cache.init(this->ref, this->default_global_size,
543-
this->default_local_size);
544-
auto buffer = this->cache.template get<value_type>(this->comm);
545+
auto buffer = this->cache.template get<value_type>(
546+
this->ref, this->comm, this->default_global_size,
547+
this->default_local_size);
545548

546549
gko::experimental::distributed::detail::GenericVectorCache cache(
547550
std::move(this->cache));
@@ -554,9 +557,9 @@ TYPED_TEST(GenericVectorCache, WorkspaceIsNotMoved)
554557
TYPED_TEST(GenericVectorCache, WorkspaceIsNotCopyAssigned)
555558
{
556559
using value_type = typename TestFixture::value_type;
557-
this->cache.init(this->ref, this->default_global_size,
558-
this->default_local_size);
559-
auto buffer = this->cache.template get<value_type>(this->comm);
560+
auto buffer = this->cache.template get<value_type>(
561+
this->ref, this->comm, this->default_global_size,
562+
this->default_local_size);
560563
gko::experimental::distributed::detail::GenericVectorCache cache;
561564

562565
cache = this->cache;
@@ -569,9 +572,9 @@ TYPED_TEST(GenericVectorCache, WorkspaceIsNotCopyAssigned)
569572
TYPED_TEST(GenericVectorCache, WorkspaceIsNotMoveAssigned)
570573
{
571574
using value_type = typename TestFixture::value_type;
572-
this->cache.init(this->ref, this->default_global_size,
573-
this->default_local_size);
574-
auto buffer = this->cache.template get<value_type>(this->comm);
575+
auto buffer = this->cache.template get<value_type>(
576+
this->ref, this->comm, this->default_global_size,
577+
this->default_local_size);
575578
gko::experimental::distributed::detail::GenericVectorCache cache;
576579

577580
cache = std::move(this->cache);

include/ginkgo/core/distributed/vector_cache.hpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -123,36 +123,28 @@ class GenericVectorCache {
123123
GenericVectorCache& operator=(GenericVectorCache&&) noexcept;
124124

125125
/**
126-
* Initializes the buffered vector configuration.
126+
* Pointer access to the distributed vector view with specific type on the
127+
* underlying workspace Initializes the workspace, if
128+
* - the workspace is null,
129+
* - the sizes differ,
130+
* - the executor differs.
127131
*
128132
* @param exec Executor associated with the buffered vector
133+
* @param comm Communicator associated with the buffered vector
129134
* @param global_size Global size of the buffered vector
130135
* @param local_size Processor-local size of the buffered vector, uses
131136
* local_size[1] as the stride
132-
*/
133-
void init(std::shared_ptr<const Executor> exec, dim<2> global_size,
134-
dim<2> local_size) const;
135-
136-
/**
137-
* Pointer access to the underlying vector with specific type.
138-
* Initializes the buffered vector, if
139-
* - the current vector is null,
140-
* - the sizes differ,
141-
* - the executor differs.
142-
*
143-
* @param comm Communicator associated with the buffered vector
144137
*
145138
* @return Pointer to the vector view.
146139
*/
147140
template <typename ValueType>
148141
std::shared_ptr<Vector<ValueType>> get(
149-
gko::experimental::mpi::communicator comm) const;
142+
std::shared_ptr<const Executor> exec,
143+
gko::experimental::mpi::communicator comm, dim<2> global_size,
144+
dim<2> local_size) const;
150145

151146
private:
152147
mutable array<char> workspace;
153-
mutable std::shared_ptr<const Executor> exec_;
154-
mutable dim<2> global_size_;
155-
mutable dim<2> local_size_;
156148
};
157149

158150

0 commit comments

Comments
 (0)