Skip to content

Commit 83a7beb

Browse files
committed
enable distributed matrix mixed precision
1 parent e4bc8db commit 83a7beb

File tree

2 files changed

+84
-31
lines changed

2 files changed

+84
-31
lines changed

core/distributed/matrix.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,14 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
438438
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
439439
const LinOp* b, LinOp* x) const
440440
{
441-
distributed::precision_dispatch_real_complex<ValueType>(
441+
distributed::mixed_precision_dispatch_real_complex<ValueType>(
442442
[this](const auto dense_b, auto dense_x) {
443-
auto x_exec = dense_x->get_executor();
444443
using x_value_type =
445444
typename std::decay_t<decltype(*dense_x)>::value_type;
446-
auto local_x = gko::matrix::Dense<ValueType>::create(
445+
using b_value_type =
446+
typename std::decay_t<decltype(*dense_b)>::value_type;
447+
auto x_exec = dense_x->get_executor();
448+
auto local_x = gko::matrix::Dense<x_value_type>::create(
447449
x_exec, dense_x->get_local_vector()->get_size(),
448450
gko::make_array_view(
449451
x_exec,
@@ -456,8 +458,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
456458
init_recv_buffers(exec, row_gatherer_.get(), dense_b->get_size()[1],
457459
recv_buffer_, host_recv_buffer_);
458460
auto host_recv_vector =
459-
host_recv_buffer_.template get<ValueType>(comm);
460-
auto recv_vector = recv_buffer_.template get<ValueType>(comm);
461+
host_recv_buffer_.template get<b_value_type>(comm);
462+
auto recv_vector = recv_buffer_.template get<b_value_type>(comm);
461463
auto recv_ptr = mpi::requires_host_buffer(exec, comm)
462464
? host_recv_vector.get()
463465
: recv_vector.get();
@@ -481,13 +483,17 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
481483
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
482484
const LinOp* alpha, const LinOp* b, const LinOp* beta, LinOp* x) const
483485
{
484-
distributed::precision_dispatch_real_complex<ValueType>(
485-
[this](const auto local_alpha, const auto dense_b,
486-
const auto local_beta, auto dense_x) {
487-
const auto x_exec = dense_x->get_executor();
486+
distributed::mixed_precision_dispatch_real_complex<ValueType>(
487+
[this, alpha, beta](const auto dense_b, auto dense_x) {
488488
using x_value_type =
489489
typename std::decay_t<decltype(*dense_x)>::value_type;
490-
auto local_x = gko::matrix::Dense<ValueType>::create(
490+
using b_value_type =
491+
typename std::decay_t<decltype(*dense_b)>::value_type;
492+
const auto x_exec = dense_x->get_executor();
493+
auto local_alpha = gko::make_temporary_conversion<ValueType>(alpha);
494+
auto local_beta =
495+
gko::make_temporary_conversion<x_value_type>(beta);
496+
auto local_x = gko::matrix::Dense<x_value_type>::create(
491497
x_exec, dense_x->get_local_vector()->get_size(),
492498
gko::make_array_view(
493499
x_exec,
@@ -500,24 +506,24 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
500506
init_recv_buffers(exec, row_gatherer_.get(), dense_b->get_size()[1],
501507
recv_buffer_, host_recv_buffer_);
502508
auto host_recv_vector =
503-
host_recv_buffer_.template get<ValueType>(comm);
504-
auto recv_vector = recv_buffer_.template get<ValueType>(comm);
509+
host_recv_buffer_.template get<b_value_type>(comm);
510+
auto recv_vector = recv_buffer_.template get<b_value_type>(comm);
505511
auto recv_ptr = mpi::requires_host_buffer(exec, comm)
506512
? host_recv_vector.get()
507513
: recv_vector.get();
508514
auto req = this->row_gatherer_->apply_async(dense_b, recv_ptr);
509-
local_mtx_->apply(local_alpha, dense_b->get_local_vector(),
510-
local_beta, local_x);
515+
local_mtx_->apply(local_alpha.get(), dense_b->get_local_vector(),
516+
local_beta.get(), local_x);
511517
req.wait();
512518

513519
if (recv_ptr != recv_vector.get()) {
514520
recv_vector->copy_from(host_recv_vector);
515521
}
516522
non_local_mtx_->apply(
517-
local_alpha, recv_vector->get_local_vector(),
523+
local_alpha.get(), recv_vector->get_local_vector(),
518524
one_scalar_.template get<x_value_type>().get(), local_x);
519525
},
520-
alpha, b, beta, x);
526+
b, x);
521527
}
522528

523529

include/ginkgo/core/base/precision_dispatch.hpp

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,12 @@ make_temporary_conversion(Ptr&& matrix)
4949
using Pointee = detail::pointee<Ptr>;
5050
using Dense = matrix::Dense<ValueType>;
5151
using NextDense = matrix::Dense<next_precision<ValueType>>;
52-
using Next2Dense = matrix::Dense<next_precision<ValueType, 2>>;
53-
using Next3Dense = matrix::Dense<next_precision<ValueType, 3>>;
52+
using NextNextDense =
53+
matrix::Dense<next_precision<next_precision<ValueType>>>;
5454
using MaybeConstDense =
5555
std::conditional_t<std::is_const<Pointee>::value, const Dense, Dense>;
56-
auto result =
57-
detail::temporary_conversion<MaybeConstDense>::template create<
58-
NextDense, Next2Dense, Next3Dense>(matrix);
56+
auto result = detail::temporary_conversion<
57+
MaybeConstDense>::template create<NextDense, NextNextDense>(matrix);
5958
if (!result) {
6059
GKO_NOT_SUPPORTED(matrix);
6160
}
@@ -230,17 +229,14 @@ void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
230229
#ifdef GINKGO_MIXED_PRECISION
231230
using fst_type = matrix::Dense<ValueType>;
232231
using snd_type = matrix::Dense<next_precision<ValueType>>;
233-
using trd_type = matrix::Dense<next_precision<ValueType, 2>>;
234-
using fth_type = matrix::Dense<next_precision<ValueType, 3>>;
232+
using trd_type = matrix::Dense<next_precision<next_precision<ValueType>>>;
235233
auto dispatch_out_vector = [&](auto dense_in) {
236234
if (auto dense_out = dynamic_cast<fst_type*>(out)) {
237235
fn(dense_in, dense_out);
238236
} else if (auto dense_out = dynamic_cast<snd_type*>(out)) {
239237
fn(dense_in, dense_out);
240238
} else if (auto dense_out = dynamic_cast<trd_type*>(out)) {
241239
fn(dense_in, dense_out);
242-
} else if (auto dense_out = dynamic_cast<fth_type*>(out)) {
243-
fn(dense_in, dense_out);
244240
} else {
245241
GKO_NOT_SUPPORTED(out);
246242
}
@@ -251,8 +247,6 @@ void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
251247
dispatch_out_vector(dense_in);
252248
} else if (auto dense_in = dynamic_cast<const trd_type*>(in)) {
253249
dispatch_out_vector(dense_in);
254-
} else if (auto dense_in = dynamic_cast<const fth_type*>(in)) {
255-
dispatch_out_vector(dense_in);
256250
} else {
257251
GKO_NOT_SUPPORTED(in);
258252
}
@@ -347,8 +341,7 @@ gko::detail::temporary_conversion<Vector<ValueType>> make_temporary_conversion(
347341
auto result =
348342
gko::detail::temporary_conversion<Vector<ValueType>>::template create<
349343
Vector<next_precision<ValueType>>,
350-
Vector<next_precision<ValueType, 2>>,
351-
Vector<next_precision<ValueType, 3>>>(matrix);
344+
Vector<next_precision<next_precision<ValueType>>>>(matrix);
352345
if (!result) {
353346
GKO_NOT_SUPPORTED(matrix);
354347
}
@@ -365,8 +358,8 @@ make_temporary_conversion(const LinOp* matrix)
365358
{
366359
auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
367360
template create<Vector<next_precision<ValueType>>,
368-
Vector<next_precision<ValueType, 2>>,
369-
Vector<next_precision<ValueType, 3>>>(matrix);
361+
Vector<next_precision<next_precision<ValueType>>>>(
362+
matrix);
370363
if (!result) {
371364
GKO_NOT_SUPPORTED(matrix);
372365
}
@@ -395,6 +388,39 @@ void precision_dispatch(Function fn, Args*... linops)
395388
}
396389

397390

391+
template <typename ValueType, typename Function>
392+
void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
393+
{
394+
#ifdef GINKGO_MIXED_PRECISION
395+
using fst_type = Vector<ValueType>;
396+
using snd_type = Vector<next_precision<ValueType>>;
397+
using trd_type = Vector<next_precision<next_precision<ValueType>>>;
398+
auto dispatch_out_vector = [&](auto vector_in) {
399+
if (auto vector_out = dynamic_cast<fst_type*>(out)) {
400+
fn(vector_in, vector_out);
401+
} else if (auto vector_out = dynamic_cast<snd_type*>(out)) {
402+
fn(vector_in, vector_out);
403+
} else if (auto vector_out = dynamic_cast<trd_type*>(out)) {
404+
fn(vector_in, vector_out);
405+
} else {
406+
GKO_NOT_SUPPORTED(out);
407+
}
408+
};
409+
if (auto vector_in = dynamic_cast<const fst_type*>(in)) {
410+
dispatch_out_vector(vector_in);
411+
} else if (auto vector_in = dynamic_cast<const snd_type*>(in)) {
412+
dispatch_out_vector(vector_in);
413+
} else if (auto vector_in = dynamic_cast<const trd_type*>(in)) {
414+
dispatch_out_vector(vector_in);
415+
} else {
416+
GKO_NOT_SUPPORTED(in);
417+
}
418+
#else
419+
precision_dispatch<ValueType>(fn, in, out);
420+
#endif
421+
}
422+
423+
398424
/**
399425
* Calls the given function with the given LinOps temporarily converted to
400426
* experimental::distributed::Vector<ValueType>* as parameters.
@@ -428,6 +454,27 @@ void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
428454
}
429455

430456

457+
template <typename ValueType, typename Function>
458+
void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
459+
LinOp* out)
460+
{
461+
auto complex_to_real = !(
462+
is_complex<ValueType>() ||
463+
dynamic_cast<const ConvertibleTo<experimental::distributed::Vector<>>*>(
464+
in));
465+
if (complex_to_real) {
466+
distributed::mixed_precision_dispatch<to_complex<ValueType>>(
467+
[&fn](auto vector_in, auto vector_out) {
468+
fn(vector_in->create_real_view().get(),
469+
vector_out->create_real_view().get());
470+
},
471+
in, out);
472+
} else {
473+
distributed::mixed_precision_dispatch<ValueType>(fn, in, out);
474+
}
475+
}
476+
477+
431478
/**
432479
* @copydoc precision_dispatch_real_complex(Function, const LinOp*, LinOp*)
433480
*/

0 commit comments

Comments
 (0)