Skip to content

Commit 269e89b

Browse files
committed
EAMxx: add a pytorch ml emulator test for cld_frac
Note: the emulator is NOT good, it just showcases the capability
1 parent 4b5c9e2 commit 269e89b

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

components/eamxx/tests/single-process/cld_fraction/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,18 @@ if (EAMXX_ENABLE_PYTHON)
4949
LABELS "cldfrac;infrastructure"
5050
FIXTURES_REQUIRED "cldfrac_py;cldfrac_cpp")
5151

52+
# Run an ml emulator for cld-fraction
53+
set (PY_MODULE_NAME "cld_fraction_ml")
54+
set (PY_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR})
55+
set (POSTFIX pyml)
56+
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/input.yaml
57+
${CMAKE_CURRENT_BINARY_DIR}/input_pyml.yaml)
58+
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/output.yaml
59+
${CMAKE_CURRENT_BINARY_DIR}/output_pyml.yaml)
60+
61+
# Test the process with python ml emulator
62+
CreateUnitTestFromExec(cld_fraction_standalone_pyml cld_fraction_standalone
63+
EXE_ARGS "--args -ifile=input_pyml.yaml"
64+
LABELS cld_fraction physics
65+
FIXTURES_SETUP cldfrac_pyml)
5266
endif()
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import os
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
class CldFracNet(nn.Module):
7+
def __init__(self, input_size, output_size, neuron_count=64):
8+
super(CldFracNet, self).__init__()
9+
# emulate cld_ice = (qi > 1e-5)
10+
self.ice1 = nn.Linear(input_size, neuron_count)
11+
self.ice2 = nn.Linear(neuron_count, output_size)
12+
# emulate cld_tot = max(cld_ice, cld_liq)
13+
self.tot1 = nn.Linear(input_size*2, neuron_count)
14+
self.tot2 = nn.Linear(neuron_count, output_size)
15+
# a relu for fun
16+
self.relu = nn.ReLU()
17+
# sigmoid for categorical ice output
18+
self.sigmoid = nn.Sigmoid()
19+
20+
def forward(self, qi, liq):
21+
# First, compute cld_ice from qi
22+
y11 = self.ice1(qi)
23+
y12 = self.relu(y11)
24+
y13 = self.ice2(y12)
25+
# Apply sigmoid to get probabilities
26+
y13_probabilities = self.sigmoid(y13)
27+
28+
# During training, use straight-through estimator for gradients
29+
# During inference, use hard binary values
30+
if self.training:
31+
# Straight-through estimator: forward pass uses binary, backward pass uses sigmoid
32+
y13_binary = (y13_probabilities > 0.5).float()
33+
y13_categorical = y13_binary - y13_probabilities.detach() + y13_probabilities
34+
else:
35+
# During inference, use hard binary values
36+
y13_categorical = (y13_probabilities > 0.5).float()
37+
38+
# Now compute cld_tot from cld_ice and cld_liq
39+
y21 = self.tot1(torch.cat((liq, y13_categorical), dim=1))
40+
y22 = self.relu(y21)
41+
y23 = self.tot2(y22)
42+
return y13_categorical, y23
43+
44+
model = None
45+
46+
def init ():
47+
global model
48+
49+
# For this test, hard code nlevs, as well as pth file name/path
50+
nlevs = 72
51+
current_file_directory = os.path.dirname(os.path.abspath(__file__))
52+
model_file = f"{current_file_directory}/cldfrac_net_weights.pth"
53+
54+
model = CldFracNet(nlevs,nlevs)
55+
model.load_state_dict(torch.load(model_file,map_location=torch.device('cpu')))
56+
57+
def main (ice_threshold, ice_4out_threshold,
58+
qi, liq_cld_frac,
59+
ice_cld_frac, tot_cld_frac,
60+
ice_cld_frac_4out, tot_cld_frac_4out):
61+
global model
62+
63+
# Convert numpy inputs to torch arrays
64+
# Note: our pth model expects float32 arrays, so make sure we get the right dtype
65+
liq_pt = torch.tensor(liq_cld_frac, dtype=torch.float32)
66+
qi_pt = torch.tensor(qi, dtype=torch.float32)
67+
68+
# Set model in evaluation mode
69+
model.eval()
70+
71+
with torch.no_grad(): # Disable gradient for inference
72+
# Run the emulator
73+
ice_out, tot_out = model(qi_pt,liq_pt)
74+
75+
# Update inout numpy arrays inplace
76+
ice_cld_frac[:] = ice_out.cpu().numpy()
77+
tot_cld_frac[:] = tot_out.cpu().numpy()
94.2 KB
Binary file not shown.

0 commit comments

Comments
 (0)