Contributor

@mahf708 mahf708 commented Jul 27, 2025

Allows the vertical and horizontal contractions to take masked fields and process them accordingly. This relies on the input fields carrying a mask_data field, which is used to zero out masked entries. For the average specialization, the weights must also be renormalized.

--

[BFB]

@mahf708 mahf708 requested review from bartgol and tcclevenger July 27, 2025 23:49
@mahf708 mahf708 added the EAMxx Issues related to EAMxx label Jul 27, 2025
@mahf708 mahf708 force-pushed the mahf708/eamxx/contract-masks branch from 606d688 to f1dd1b0 Compare July 28, 2025 01:03
@mahf708 mahf708 marked this pull request as draft July 28, 2025 04:31
@mahf708 mahf708 force-pushed the mahf708/eamxx/contract-masks branch from 2a3393a to 34e7697 Compare July 28, 2025 13:18
@mahf708 mahf708 marked this pull request as ready for review July 28, 2025 13:18
l_out.size(), MPI_SUM);
f_tmp.sync_to_dev();

// update f_out by dividing it with f_tmp
Contributor Author

I wanted to use f_out.inv_scale(f_tmp) here but that has a smart guard against non-const fields (even though we often get fields as const, but modify their views), and I didn't want to accommodate that throughout the chain...

Contributor

I'm not sure what guard you are referring to. That said, you can't use that fcn b/c that would divide by f_tmp even where it is 0, while here it seems you want to set f_out to 0 wherever f_tmp=0.

Contributor

Speaking of, why setting it to 0? In I/O, 0 is a "valid" number, which will be considered when doing time averages (for instance). Perhaps we should set fill_value?

This comment was marked as outdated.

Contributor Author

Now, recall: the reason I am being flexible with whatever weights the user gives me is that I want users to scale these contractions as they see fit (say sum vs. avg, vs. something unhinged that I cannot think of). For cases where we are in control (like horiz_avg or vert_avg, etc.), I want the contraction to return precisely nanmean/nanaverage in numpy-speak; users who want something else can just use the bare-bones impl to sum whatever they want.

Contributor Author

Speaking of, why setting it to 0? In I/O, 0 is a "valid" number, which will be considered when doing time averages (for instance). Perhaps we should set fill_value?

you are right, this needs to be fill_value
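
For readers following along, the division step discussed in this thread could look roughly like the sketch below. It is plain host-side C++ rather than Kokkos, and `divide_or_fill` and its parameters are hypothetical names invented for illustration; they are not part of EAMxx.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not an EAMxx function): divides each entry of `out`
// by the matching entry of `denom`. Where the denominator is zero, i.e. every
// contributing entry was masked, it writes `fill_value` instead of 0 so that
// downstream I/O does not treat the entry as valid data.
template <typename ST>
void divide_or_fill(std::vector<ST>& out, const std::vector<ST>& denom,
                    ST fill_value) {
  for (std::size_t i = 0; i < out.size(); ++i) {
    out[i] = (denom[i] != ST(0)) ? out[i] / denom[i] : fill_value;
  }
}
```

Here `out` plays the role of f_out and `denom` the role of f_tmp (the reduced sum of unmasked weights); the actual fill value would come from whatever constant the code base uses.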

Comment on lines +122 to +125
+ template <typename ST, bool AVG = 1>
  void horiz_contraction(const Field &f_out, const Field &f_in,
-                        const Field &weight, const ekat::Comm *comm = nullptr) {
+                        const Field &weight, const ekat::Comm *comm = nullptr,
+                        const Field &f_tmp = Field()) {
Contributor

I'm still looking through to try and understand the changes, but I think comments here on what AVG option is and what f_tmp should be would be helpful.

Also, I don't like having a input field being called "tmp", since I associate that with some internal variable that isn't user facing, but this may be where I need to understand the impl a little more.

  // - rank-2, with only COL and LEV/ILEV dimensions
  // NOTE: we assume the LEV/ILEV dimension is NOT partitioned.
- template <typename ST>
+ template <typename ST, int AVG = 0>
Contributor

Why is AVG defaulted on for horizontal, but off for vertical?

Contributor Author

This is a good point (and on purpose). The reason is that for the horizontal one, our primary support is for weighted averaging. For the vertical, we are agnostic to the application (and it doesn't matter which one we default to, so I left it at 0). I can force users to specify both or default both to 0 or 1. What would you recommend?

Contributor

Since AVG can be 0 or 1, can we instead use a boolean, for clarity?

Comment on lines 521 to 525
Kokkos::deep_copy(v_out, n);
if (is_avg_masked) {
  ST tmp = d != 0 ? n / d : 0;
  Kokkos::deep_copy(v_out, tmp);
}
Contributor

We can avoid an unnecessary deep copy in the case of average and masked.

Suggested change
- Kokkos::deep_copy(v_out, n);
- if (is_avg_masked) {
-   ST tmp = d != 0 ? n / d : 0;
-   Kokkos::deep_copy(v_out, tmp);
- }
+ if (is_avg_masked) {
+   ST tmp = d != 0 ? n / d : 0;
+   Kokkos::deep_copy(v_out, tmp);
+ } else {
+   Kokkos::deep_copy(v_out, n);
+ }

Comment on lines +219 to +224
// if f has a mask and we are averaging, need to call the avg specialization
if (m_contract_method == "avg" && f.get_header().has_extra_data("mask_data")) {
  vert_contraction<Real,1>(d, f, m_weighting);
} else {
  vert_contraction<Real,0>(d, f, m_weighting);
}
Contributor

Should vert_contraction() care about averaging at all? Above, in the case of m_contract_method == "avg", aren't you computing the average of the weights that then get applied here? Then with vert_contraction<AVG=1> you are again averaging, even though the weights already represent an average?

Contributor Author

Yes, but only in the case of averaging masked quantities. The 0 or 1 is not really about AVG per se; it is about renormalizing the weights after taking out the entries corresponding to fill values. See below:

lev  T_mid  qc
1    200    0.1
2    300    0.2
3    310    _
4    300    0.1

For simplicity, let's say we are using equal weighting, i.e., 1/4 (0.25) for each.

In the case of T_mid (no masking), it won't matter if we renormalize or not: 0.25 * (200+300+310+300) = 0.25 * (200+300+310+300) / 1

In the case of qc (masking), it does matter if we renormalize or not: 0.25 * (0.1 + 0.2 + 0.1) != 0.25 * (0.1 + 0.2 + 0.1) / 0.75

Maybe I should rename the thing in the template to something better? AVG is really misleading...

Contributor

Maybe I should rename the thing in the template to something better? AVG is really misleading...

Yes, please! I also was confused by this. I suppose you could call it RenormalizeAfterMasking or something like that. I always prefer very_verbose_and_long_names_but_very_clear over cryptic/misleading ones...

Other ideas: UseOnlyNonMaskedWeights, AlsoMaskWeights, NormalizeWeightsWithMask,...

Contributor

Thanks @mahf708, I understand now. And yes, I also agree with changing the name.

@bartgol bartgol left a comment

Do I understand correctly that, in case of mask_data present, you want to be able to NOT consider the weights corresponding to fill-val entries of f_in when computing the weight contraction? If so, why not always make that the case?

  // - The first dimension is for the columns (COL)
  // - There can be only up to 3 dimensions of f_in
- template <typename ST>
+ template <typename ST, bool AVG = 1>
Contributor

Yes, 1 means "true" and 0 means "false", but since you are using bool as the type, why not use the slightly more verbose/self-explanatory value true instead of 1?

d_acc += v_w(i) * mask;
},
Kokkos::Sum<ST>(n), Kokkos::Sum<ST>(d));
Kokkos::deep_copy(v_out, n);
Contributor

Looks like you should move this inside the first branch of the if, to avoid a pointless deep_copy in case the else if triggers?

f_out.sync_to_dev();
if (is_comm_avg_masked) {
f_tmp.sync_to_host();
comm->all_reduce(f_tmp.template get_internal_view_data<ST, Host>(),
Contributor

Be aware that get_internal_view_data is a power-user method. In particular, it is not safe to use with fields that are subfields of another field (where their entries may not be contiguous in memory). Maybe you could add a check at the fcn top (among all others checks on f_tmp) to ensure it's not a subfield?

TBC, I don't think anyone would ever pass a subfield as a scratch, but better be safe than sorry..

Contributor Author

Is there an alternative here?

Contributor

Not really. I mean, we could add some utility that defines an ad-hoc derived MPI_Datatype with the correct strides between data blocks (something like MPI_Type_vector). But until we see the need for this, I think we should simply check that fields are not in fact sub-fields and be done with it...

Contributor

To be precise, we can support subfields as long as they are subfields along the slowest striding dim, since once we extract the view, we get contiguous data, but it would require a few if/else (basically extract the view, then take the view pointer). The method get_internal_view_data does not do "get view, then take the view pointer", but it simply returns the internal view pointer, which, for subfields, is the pointer to the global "super"-field. Boring impl details.

Contributor

The inline docs for the method explain the risk:

  // WARNING: this is a power-user method. Its implementation, including assumptions
  //          on pre/post conditions, may change in the future. Use at your own risk!
  //          Read carefully the instructions below.
  // Allows to get a raw pointer (host or device) from the view stored in the field.
  // The user must provide the pointed type for the returned pointer. Such type must
  // either be char or the correct data type of this field.
  // Notice that the view stored may contain more data than just the
  // data of the current field. This can happen in two cases (possibly simultaneously).
  //   - The field was allocated in a way that allows packing. In this case,
  //     there may be padding along the last *physical* dimension of the field.
  //     In the stored 1d view, the padding may appear interleaved with actual
  //     data (due to the view layout being LayoutRight).
  //   - The field is a subfield of another field. In this case, this class
  //     actually stores the "parent" field view. So when calling this method,
  //     you will actually get the raw pointer of the parent field view.
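
To make the contiguity point in this thread concrete, here is a small index-only sketch for a row-major (LayoutRight) n0 x n1 buffer. It does not use the Field class at all; `slice_slowest`, `slice_fastest`, and `is_contiguous` are hypothetical helpers written for this illustration.

```cpp
#include <cstddef>
#include <vector>

// Flat offsets of the slice obtained by fixing the slowest (first) index i0
// of a row-major n0 x n1 buffer: one contiguous run of length n1.
std::vector<std::size_t> slice_slowest(std::size_t n1, std::size_t i0) {
  std::vector<std::size_t> idx(n1);
  for (std::size_t j = 0; j < n1; ++j) {
    idx[j] = i0 * n1 + j;
  }
  return idx;
}

// Flat offsets of the slice obtained by fixing the fastest (last) index i1:
// entries are separated by a stride of n1.
std::vector<std::size_t> slice_fastest(std::size_t n0, std::size_t n1,
                                       std::size_t i1) {
  std::vector<std::size_t> idx(n0);
  for (std::size_t i = 0; i < n0; ++i) {
    idx[i] = i * n1 + i1;
  }
  return idx;
}

// True when consecutive offsets differ by exactly one.
bool is_contiguous(const std::vector<std::size_t>& idx) {
  for (std::size_t k = 1; k < idx.size(); ++k) {
    if (idx[k] != idx[k - 1] + 1) {
      return false;
    }
  }
  return true;
}
```

A slice along the slowest dimension could be handed to a plain all_reduce as long as the pointer is offset to the start of the slice, whereas a slice along the fastest dimension would need a strided datatype or a packed copy. Since get_internal_view_data returns the parent's pointer in both cases, rejecting subfields up front is the simpler guard.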

Contributor Author

mahf708 commented Jul 28, 2025

Do I understand correctly that, in case of mask_data present, you want to be able to NOT consider the weights corresponding to fill-val entries of f_in when computing the weight contraction? If so, why not always make that the case?

It adds extra work (e.g., in horiz_contraction, I couldn't figure out a way to do the normalization after weighting without a second dummy field), and we are not guaranteed to have mask_data present, so we have to branch. Were you thinking of something specific I am missing? Please help me improve the impl if possible :D

Contributor

bartgol commented Jul 28, 2025

Do I understand correctly that, in case of mask_data present, you want to be able to NOT consider the weights corresponding to fill-val entries of f_in when computing the weight contraction? If so, why not always make that the case?

It adds extra work (e.g., in horiz_contraction, I couldn't figure out a way to do the normalization after weighting without a second dummy field), and we are not guaranteed to have mask_data present, so we have to branch. Were you thinking of something specific I am missing? Please help me improve the impl if possible :D

What I mean is: does the AVG template arg determine whether masked entries should be filtered out from the weight field sum, to basically normalize the avg of F with the sum of W only where F!=fill_val? If so, why not do it all the time?

As for the temporary, I don't think you have many alternatives.

@mahf708 mahf708 marked this pull request as draft August 5, 2025 14:39
Contributor Author

mahf708 commented Aug 7, 2025

I don't want to integrate this separately. I will just do it as part of #7508

@mahf708 mahf708 closed this Aug 7, 2025
@mahf708 mahf708 deleted the mahf708/eamxx/contract-masks branch August 7, 2025 14:28