@@ -49,13 +49,12 @@ make_temporary_conversion(Ptr&& matrix)
49
49
using Pointee = detail::pointee<Ptr>;
50
50
using Dense = matrix::Dense<ValueType>;
51
51
using NextDense = matrix::Dense<next_precision<ValueType>>;
52
- using Next2Dense = matrix::Dense<next_precision<ValueType, 2 >>;
53
- using Next3Dense = matrix::Dense<next_precision<ValueType, 3 >>;
52
+ using NextNextDense =
53
+ matrix::Dense<next_precision<next_precision< ValueType> >>;
54
54
using MaybeConstDense =
55
55
std::conditional_t <std::is_const<Pointee>::value, const Dense, Dense>;
56
- auto result =
57
- detail::temporary_conversion<MaybeConstDense>::template create<
58
- NextDense, Next2Dense, Next3Dense>(matrix);
56
+ auto result = detail::temporary_conversion<
57
+ MaybeConstDense>::template create<NextDense, NextNextDense>(matrix);
59
58
if (!result) {
60
59
GKO_NOT_SUPPORTED (matrix);
61
60
}
@@ -230,17 +229,14 @@ void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
230
229
#ifdef GINKGO_MIXED_PRECISION
231
230
using fst_type = matrix::Dense<ValueType>;
232
231
using snd_type = matrix::Dense<next_precision<ValueType>>;
233
- using trd_type = matrix::Dense<next_precision<ValueType, 2 >>;
234
- using fth_type = matrix::Dense<next_precision<ValueType, 3 >>;
232
+ using trd_type = matrix::Dense<next_precision<next_precision<ValueType>>>;
235
233
auto dispatch_out_vector = [&](auto dense_in) {
236
234
if (auto dense_out = dynamic_cast <fst_type*>(out)) {
237
235
fn (dense_in, dense_out);
238
236
} else if (auto dense_out = dynamic_cast <snd_type*>(out)) {
239
237
fn (dense_in, dense_out);
240
238
} else if (auto dense_out = dynamic_cast <trd_type*>(out)) {
241
239
fn (dense_in, dense_out);
242
- } else if (auto dense_out = dynamic_cast <fth_type*>(out)) {
243
- fn (dense_in, dense_out);
244
240
} else {
245
241
GKO_NOT_SUPPORTED (out);
246
242
}
@@ -251,8 +247,6 @@ void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
251
247
dispatch_out_vector (dense_in);
252
248
} else if (auto dense_in = dynamic_cast <const trd_type*>(in)) {
253
249
dispatch_out_vector (dense_in);
254
- } else if (auto dense_in = dynamic_cast <const fth_type*>(in)) {
255
- dispatch_out_vector (dense_in);
256
250
} else {
257
251
GKO_NOT_SUPPORTED (in);
258
252
}
@@ -347,8 +341,7 @@ gko::detail::temporary_conversion<Vector<ValueType>> make_temporary_conversion(
347
341
auto result =
348
342
gko::detail::temporary_conversion<Vector<ValueType>>::template create<
349
343
Vector<next_precision<ValueType>>,
350
- Vector<next_precision<ValueType, 2 >>,
351
- Vector<next_precision<ValueType, 3 >>>(matrix);
344
+ Vector<next_precision<next_precision<ValueType>>>>(matrix);
352
345
if (!result) {
353
346
GKO_NOT_SUPPORTED (matrix);
354
347
}
@@ -365,8 +358,8 @@ make_temporary_conversion(const LinOp* matrix)
365
358
{
366
359
auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
367
360
template create<Vector<next_precision<ValueType>>,
368
- Vector<next_precision<ValueType, 2 >>,
369
- Vector<next_precision<ValueType, 3 >>>( matrix);
361
+ Vector<next_precision<next_precision< ValueType>>>>(
362
+ matrix);
370
363
if (!result) {
371
364
GKO_NOT_SUPPORTED (matrix);
372
365
}
@@ -395,6 +388,39 @@ void precision_dispatch(Function fn, Args*... linops)
395
388
}
396
389
397
390
391
+ template <typename ValueType, typename Function>
392
+ void mixed_precision_dispatch (Function fn, const LinOp* in, LinOp* out)
393
+ {
394
+ #ifdef GINKGO_MIXED_PRECISION
395
+ using fst_type = Vector<ValueType>;
396
+ using snd_type = Vector<next_precision<ValueType>>;
397
+ using trd_type = Vector<next_precision<next_precision<ValueType>>>;
398
+ auto dispatch_out_vector = [&](auto vector_in) {
399
+ if (auto vector_out = dynamic_cast <fst_type*>(out)) {
400
+ fn (vector_in, vector_out);
401
+ } else if (auto vector_out = dynamic_cast <snd_type*>(out)) {
402
+ fn (vector_in, vector_out);
403
+ } else if (auto vector_out = dynamic_cast <trd_type*>(out)) {
404
+ fn (vector_in, vector_out);
405
+ } else {
406
+ GKO_NOT_SUPPORTED (out);
407
+ }
408
+ };
409
+ if (auto vector_in = dynamic_cast <const fst_type*>(in)) {
410
+ dispatch_out_vector (vector_in);
411
+ } else if (auto vector_in = dynamic_cast <const snd_type*>(in)) {
412
+ dispatch_out_vector (vector_in);
413
+ } else if (auto vector_in = dynamic_cast <const trd_type*>(in)) {
414
+ dispatch_out_vector (vector_in);
415
+ } else {
416
+ GKO_NOT_SUPPORTED (in);
417
+ }
418
+ #else
419
+ precision_dispatch<ValueType>(fn, in, out);
420
+ #endif
421
+ }
422
+
423
+
398
424
/* *
399
425
* Calls the given function with the given LinOps temporarily converted to
400
426
* experimental::distributed::Vector<ValueType>* as parameters.
@@ -428,6 +454,27 @@ void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
428
454
}
429
455
430
456
457
+ template <typename ValueType, typename Function>
458
+ void mixed_precision_dispatch_real_complex (Function fn, const LinOp* in,
459
+ LinOp* out)
460
+ {
461
+ auto complex_to_real = !(
462
+ is_complex<ValueType>() ||
463
+ dynamic_cast <const ConvertibleTo<experimental::distributed::Vector<>>*>(
464
+ in));
465
+ if (complex_to_real) {
466
+ distributed::mixed_precision_dispatch<to_complex<ValueType>>(
467
+ [&fn](auto vector_in, auto vector_out) {
468
+ fn (vector_in->create_real_view ().get (),
469
+ vector_out->create_real_view ().get ());
470
+ },
471
+ in, out);
472
+ } else {
473
+ distributed::mixed_precision_dispatch<ValueType>(fn, in, out);
474
+ }
475
+ }
476
+
477
+
431
478
/* *
432
479
* @copydoc precision_dispatch_real_complex(Function, const LinOp*, LinOp*)
433
480
*/
0 commit comments