@@ -306,7 +306,7 @@ void perturb (Field& f,
306
306
}
307
307
}
308
308
309
- template <typename ST>
309
+ template <typename ST, int AVG >
310
310
void horiz_contraction (const Field &f_out, const Field &f_in,
311
311
const Field &weight, const ekat::Comm *comm) {
312
312
using KT = ekat::KokkosTypes<DefaultDevice>;
@@ -318,34 +318,63 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
318
318
auto l_out = f_out.get_header ().get_identifier ().get_layout ();
319
319
auto l_in = f_in.get_header ().get_identifier ().get_layout ();
320
320
321
+ auto is_masked = f_in.get_header ().has_extra_data (" mask_data" );
322
+
321
323
auto v_w = weight.get_view <const ST *>();
322
324
323
325
const int ncols = l_in.dim (0 );
324
326
327
+ bool is_avg = AVG; // 1 is avg; 0 is sum
328
+
325
329
switch (l_in.rank ()) {
326
330
case 1 : {
327
331
auto v_in = f_in.get_view <const ST *>();
332
+ auto v_m = is_masked ? f_in.get_header ().get_extra_data <Field>(" mask_data" ).get_view <const ST *>() : v_in;
328
333
auto v_out = f_out.get_view <ST>();
334
+ ST n = 0 , d = 0 ;
329
335
Kokkos::parallel_reduce (
330
336
f_out.name (), RangePolicy (0 , ncols),
331
- KOKKOS_LAMBDA (const int i, ST &ls) { ls += v_w (i) * v_in (i); },
332
- v_out);
337
+ KOKKOS_LAMBDA (const int i, ST &n_acc, ST &d_acc) {
338
+ auto mask = is_masked ? v_m (i) : ST (1.0 );
339
+ n_acc += v_w (i) * v_in (i) * mask;
340
+ d_acc += v_w (i) * mask;
341
+ },
342
+ Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
343
+ if (is_avg) {
344
+ ST tmp = d != 0 ? n / d : 0 ;
345
+ Kokkos::deep_copy (v_out, tmp);
346
+ } else {
347
+ Kokkos::deep_copy (v_out, n);
348
+ }
333
349
} break ;
334
350
case 2 : {
335
351
auto v_in = f_in.get_view <const ST **>();
352
+ auto v_m = is_masked ? f_in.get_header ().get_extra_data <Field>(" mask_data" ).get_view <const ST **>() : v_in;
336
353
auto v_out = f_out.get_view <ST *>();
337
354
const int d1 = l_in.dim (1 );
338
355
auto p = ESU::get_default_team_policy (d1, ncols);
339
356
Kokkos::parallel_for (
340
357
f_out.name (), p, KOKKOS_LAMBDA (const TeamMember &tm) {
341
358
const int j = tm.league_rank ();
359
+ ST n = 0 , d = 0 ;
342
360
Kokkos::parallel_reduce (
343
361
Kokkos::TeamVectorRange (tm, ncols),
344
- [&](int i, ST &ac) { ac += v_w (i) * v_in (i, j); }, v_out (j));
362
+ [&](int i, ST &n_acc, ST &d_acc) {
363
+ auto mask = is_masked ? v_m (i, j) : ST (1.0 );
364
+ n_acc += v_w (i) * v_in (i, j) * mask;
365
+ d_acc += v_w (i) * mask;
366
+ },
367
+ Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
368
+ if (is_avg) {
369
+ v_out (j) = d != 0 ? n / d : 0 ;
370
+ } else {
371
+ v_out (j) = n;
372
+ }
345
373
});
346
374
} break ;
347
375
case 3 : {
348
376
auto v_in = f_in.get_view <const ST ***>();
377
+ auto v_m = is_masked ? f_in.get_header ().get_extra_data <Field>(" mask_data" ).get_view <const ST ***>() : v_in;
349
378
auto v_out = f_out.get_view <ST **>();
350
379
const int d1 = l_in.dim (1 );
351
380
const int d2 = l_in.dim (2 );
@@ -355,10 +384,20 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
355
384
const int idx = tm.league_rank ();
356
385
const int j = idx / d2;
357
386
const int k = idx % d2;
387
+ ST n = 0 , d = 0 ;
358
388
Kokkos::parallel_reduce (
359
389
Kokkos::TeamVectorRange (tm, ncols),
360
- [&](int i, ST &ac) { ac += v_w (i) * v_in (i, j, k); },
361
- v_out (j, k));
390
+ [&](int i, ST &n_acc, ST &d_acc) {
391
+ auto mask = is_masked ? v_m (i, j, k) : ST (1.0 );
392
+ n_acc += v_w (i) * v_in (i, j, k) * mask;
393
+ d_acc += v_w (i) * mask;
394
+ },
395
+ Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
396
+ if (is_avg) {
397
+ v_out (j, k) = d != 0 ? n / d : 0 ;
398
+ } else {
399
+ v_out (j, k) = n;
400
+ }
362
401
});
363
402
} break ;
364
403
default :
@@ -377,7 +416,7 @@ void horiz_contraction(const Field &f_out, const Field &f_in,
377
416
}
378
417
}
379
418
380
- template <typename ST>
419
+ template <typename ST, int AVG >
381
420
void vert_contraction (const Field &f_out, const Field &f_in, const Field &weight) {
382
421
using KT = ekat::KokkosTypes<DefaultDevice>;
383
422
using RangePolicy = Kokkos::RangePolicy<Field::device_t ::execution_space>;
@@ -389,6 +428,10 @@ void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight
389
428
auto l_in = f_in.get_header ().get_identifier ().get_layout ();
390
429
auto l_w = weight.get_header ().get_identifier ().get_layout ();
391
430
431
+ auto is_masked = f_in.get_header ().has_extra_data (" mask_data" );
432
+
433
+ bool is_avg = AVG; // 1 is avg, 0 is sum
434
+
392
435
const int nlevs = l_in.dim (l_in.rank () - 1 );
393
436
394
437
// To avoid duplicating code for the 1d and 2d weight cases,
@@ -404,32 +447,55 @@ void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight
404
447
405
448
switch (l_in.rank ()) {
406
449
case 1 : {
407
- auto v_w = weight.get_view <const ST *>();
408
450
auto v_in = f_in.get_view <const ST *>();
451
+ auto v_m = is_masked ? f_in.get_header ().get_extra_data <Field>(" mask_data" ).get_view <const ST *>() : v_in;
409
452
auto v_out = f_out.get_view <ST>();
453
+ ST n = 0 , d = 0 ;
410
454
Kokkos::parallel_reduce (
411
455
f_out.name (), RangePolicy (0 , nlevs),
412
- KOKKOS_LAMBDA (const int i, ST &ls) { ls += v_w (i) * v_in (i); },
413
- v_out);
456
+ KOKKOS_LAMBDA (const int i, ST &n_acc, ST &d_acc) {
457
+ auto mask = is_masked ? v_m (i) : ST (1.0 );
458
+ auto w = w1d (i);
459
+ n_acc += w * v_in (i) * mask;
460
+ d_acc += w * mask;
461
+ },
462
+ Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
463
+ if (is_avg) {
464
+ ST tmp = d != 0 ? n / d : 0 ;
465
+ Kokkos::deep_copy (v_out, tmp);
466
+ } else {
467
+ Kokkos::deep_copy (v_out, n);
468
+ }
414
469
} break ;
415
470
case 2 : {
416
471
auto v_in = f_in.get_view <const ST **>();
472
+ auto v_m = is_masked ? f_in.get_header ().get_extra_data <Field>(" mask_data" ).get_view <const ST **>() : v_in;
417
473
auto v_out = f_out.get_view <ST *>();
418
474
const int d0 = l_in.dim (0 );
419
475
auto p = ESU::get_default_team_policy (d0, nlevs);
420
476
Kokkos::parallel_for (
421
477
f_out.name (), p, KOKKOS_LAMBDA (const TeamMember &tm) {
422
478
const int i = tm.league_rank ();
479
+ ST n = 0 , d = 0 ;
423
480
Kokkos::parallel_reduce (
424
481
Kokkos::TeamVectorRange (tm, nlevs),
425
- [&](int j, ST &ac) {
426
- ac += w_is_1d ? w1d (j) * v_in (i, j) : w2d (i, j) * v_in (i, j);
482
+ [&](int j, ST &n_acc, ST &d_acc) {
483
+ auto mask = is_masked ? v_m (i, j) : ST (1.0 );
484
+ auto w = w_is_1d ? w1d (j) : w2d (i, j);
485
+ n_acc += w * v_in (i, j) * mask;
486
+ d_acc += w * mask;
427
487
},
428
- v_out (i));
488
+ Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
489
+ if (is_avg) {
490
+ v_out (i) = d != 0 ? n / d : 0 ;
491
+ } else {
492
+ v_out (i) = n;
493
+ }
429
494
});
430
495
} break ;
431
496
case 3 : {
432
497
auto v_in = f_in.get_view <const ST ***>();
498
+ auto v_m = is_masked ? f_in.get_header ().get_extra_data <Field>(" mask_data" ).get_view <const ST ***>() : v_in;
433
499
auto v_out = f_out.get_view <ST **>();
434
500
const int d0 = l_in.dim (0 );
435
501
const int d1 = l_in.dim (1 );
@@ -439,13 +505,21 @@ void vert_contraction(const Field &f_out, const Field &f_in, const Field &weight
439
505
const int idx = tm.league_rank ();
440
506
const int i = idx / d1;
441
507
const int j = idx % d1;
508
+ ST n = 0 , d = 0 ;
442
509
Kokkos::parallel_reduce (
443
510
Kokkos::TeamVectorRange (tm, nlevs),
444
- [&](int k, ST &ac) {
445
- ac += w_is_1d ? w1d (k) * v_in (i, j, k)
446
- : w2d (i, k) * v_in (i, j, k);
511
+ [&](int k, ST &n_acc, ST &d_acc) {
512
+ auto mask = is_masked ? v_m (i, j, k) : ST (1.0 );
513
+ auto w = w_is_1d ? w1d (k) : w2d (i, k);
514
+ n_acc += w * v_in (i, j, k) * mask;
515
+ d_acc += w * mask;
447
516
},
448
- v_out (i, j));
517
+ Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
518
+ if (is_avg) {
519
+ v_out (i, j) = d != 0 ? n / d : 0 ;
520
+ } else {
521
+ v_out (i, j) = n;
522
+ }
449
523
});
450
524
} break ;
451
525
default :
0 commit comments