@@ -28,13 +28,9 @@ void VertDerivativeDiag::set_grids(const std::shared_ptr<const GridsManager> gri
28
28
m_diag_name = fn + " _" + m_derivative_method + " vert_derivative" ;
29
29
30
30
auto scalar3d = g->get_3d_scalar_layout (true );
31
- if (m_derivative_method == " p" ) {
32
- add_field<Required>(" pseudo_density" , scalar3d, Pa, gn);
33
- } else if (m_derivative_method == " z" ) {
34
- add_field<Required>(" pseudo_density" , scalar3d, Pa, gn);
35
- add_field<Required>(" qv" , scalar3d, kg / kg, gn);
36
- add_field<Required>(" p_mid" , scalar3d, Pa, gn);
37
- add_field<Required>(" T_mid" , scalar3d, K, gn);
31
+ add_field<Required>(" pseudo_density" , scalar3d, Pa, gn);
32
+ if (m_derivative_method == " z" ) {
33
+ add_field<Required>(" dz" , scalar3d, m, gn);
38
34
}
39
35
}
40
36
@@ -49,24 +45,16 @@ void VertDerivativeDiag::initialize_impl(const RunType /*run_type*/) {
49
45
// TODO: support higher-dimensioned input fields
50
46
EKAT_REQUIRE_MSG (layout.rank () >= 2 && layout.rank () <= 2 ,
51
47
" Error! Field rank not supported by VertDerivativeDiag.\n "
52
- " - field name: " +
53
- fid.name () +
54
- " \n "
55
- " - field layout: " +
56
- layout.to_string () + " \n " );
48
+ " - field name: " + fid.name () + " \n "
49
+ " - field layout: " + layout.to_string () + " \n " );
57
50
EKAT_REQUIRE_MSG (layout.tags ().back () == LEV,
58
51
" Error! VertDerivativeDiag diagnostic expects a layout ending "
59
52
" with the 'LEV' tag.\n "
60
- " - field name : " +
61
- fid.name () +
62
- " \n "
63
- " - field layout: " +
64
- layout.to_string () + " \n " );
53
+ " - field name : " + fid.name () + " \n "
54
+ " - field layout: " + layout.to_string () + " \n " );
65
55
66
56
ekat::units::Units diag_units = fid.get_units ();
67
57
68
- m_denominator = get_field_in (" pseudo_density" ).clone (" denominator" );
69
-
70
58
if (m_derivative_method == " p" ) {
71
59
diag_units = fid.get_units () / Pa;
72
60
} else if (m_derivative_method == " z" ) {
@@ -94,46 +82,20 @@ void VertDerivativeDiag::compute_diagnostic_impl() {
94
82
using KT = KokkosTypes<DefaultDevice>;
95
83
using MT = typename KT::MemberType;
96
84
using TPF = ekat::TeamPolicyFactory<typename KT::ExeSpace>;
97
- const int ncols = m_denominator .get_header ().get_identifier ().get_layout ().dim (0 );
98
- const int nlevs = m_denominator .get_header ().get_identifier ().get_layout ().dim (1 );
85
+ const int ncols = f .get_header ().get_identifier ().get_layout ().dim (0 );
86
+ const int nlevs = f .get_header ().get_identifier ().get_layout ().dim (1 );
99
87
const auto policy = TPF::get_default_team_policy (ncols, nlevs);
100
88
101
- // get the denominator first
102
- if (m_derivative_method == " p" ) {
103
- m_denominator.update (dp, sp (1.0 ), sp (0.0 ));
104
- } else if (m_derivative_method == " dz" ) {
105
- // TODO: for some reason the z_mid field keeps getting set to 0
106
- // TODO: as a workaround, just calculate z_mid here (sigh...)
107
- // m_denominator.update(get_field_in("z_mid"), 1.0, 0.0);
108
- using PF = scream::PhysicsFunctions<DefaultDevice>;
109
- auto zm_v = m_denominator.get_view <Real **>();
110
- auto pm_v = get_field_in (" p_mid" ).get_view <const Real **>();
111
- auto tm_v = get_field_in (" T_mid" ).get_view <const Real **>();
112
- auto qv_v = get_field_in (" qv" ).get_view <const Real **>();
113
-
114
- Kokkos::parallel_for (
115
- " Compute dz for " + m_diagnostic_output.name (), policy, KOKKOS_LAMBDA (const MT &team) {
116
- const int icol = team.league_rank ();
117
- auto zm_icol = ekat::subview (zm_v, icol);
118
- auto dp_icol = ekat::subview (dp2d, icol);
119
- auto pm_icol = ekat::subview (pm_v, icol);
120
- auto tm_icol = ekat::subview (tm_v, icol);
121
- auto qv_icol = ekat::subview (qv_v, icol);
122
- PF::calculate_dz (team, dp_icol, pm_icol, tm_icol, qv_icol, zm_icol);
123
- });
124
- }
89
+ auto d_v = (m_derivative_method == " z" ) ? get_field_in (" dz" ).get_view <Real **>() : dp2d;
125
90
126
- auto d_v = m_denominator.get_view <Real **>();
127
91
Kokkos::parallel_for (
128
92
" Compute df / denominator for " + m_diagnostic_output.name (), policy,
129
93
KOKKOS_LAMBDA (const MT &team) {
130
94
const int icol = team.league_rank ();
131
95
auto f_icol = ekat::subview (f2d, icol); // field at midpoint
132
96
auto o_icol = ekat::subview (o2d, icol); // output at midnpoint
133
- auto d_icol =
134
- ekat::subview (d_v, icol); // recall denominator is already a difference of interfaces
135
- auto dpicol =
136
- ekat::subview (dp2d, icol); // in case of z deriv, d_icol and dpicol are not the same
97
+ auto d_icol = ekat::subview (d_v, icol); // recall denominator is already a difference of interfaces
98
+ auto dpicol = ekat::subview (dp2d, icol); // in case of z deriv, d_icol and dpicol are not the same
137
99
138
100
Kokkos::parallel_for (Kokkos::TeamVectorRange (team, nlevs), [&](const int ilev) {
139
101
// boundary points
0 commit comments