|
| 1 | +import os |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | + |
| 6 | +class CldFracNet(nn.Module): |
| 7 | + def __init__(self, input_size, output_size, neuron_count=64): |
| 8 | + super(CldFracNet, self).__init__() |
| 9 | + # emulate cld_ice = (qi > 1e-5) |
| 10 | + self.ice1 = nn.Linear(input_size, neuron_count) |
| 11 | + self.ice2 = nn.Linear(neuron_count, output_size) |
| 12 | + # emulate cld_tot = max(cld_ice, cld_liq) |
| 13 | + self.tot1 = nn.Linear(input_size*2, neuron_count) |
| 14 | + self.tot2 = nn.Linear(neuron_count, output_size) |
| 15 | + # a relu for fun |
| 16 | + self.relu = nn.ReLU() |
| 17 | + # sigmoid for categorical ice output |
| 18 | + self.sigmoid = nn.Sigmoid() |
| 19 | + |
| 20 | + def forward(self, qi, liq): |
| 21 | + # First, compute cld_ice from qi |
| 22 | + y11 = self.ice1(qi) |
| 23 | + y12 = self.relu(y11) |
| 24 | + y13 = self.ice2(y12) |
| 25 | + # Apply sigmoid to get probabilities |
| 26 | + y13_probabilities = self.sigmoid(y13) |
| 27 | + |
| 28 | + # During training, use straight-through estimator for gradients |
| 29 | + # During inference, use hard binary values |
| 30 | + if self.training: |
| 31 | + # Straight-through estimator: forward pass uses binary, backward pass uses sigmoid |
| 32 | + y13_binary = (y13_probabilities > 0.5).float() |
| 33 | + y13_categorical = y13_binary - y13_probabilities.detach() + y13_probabilities |
| 34 | + else: |
| 35 | + # During inference, use hard binary values |
| 36 | + y13_categorical = (y13_probabilities > 0.5).float() |
| 37 | + |
| 38 | + # Now compute cld_tot from cld_ice and cld_liq |
| 39 | + y21 = self.tot1(torch.cat((liq, y13_categorical), dim=1)) |
| 40 | + y22 = self.relu(y21) |
| 41 | + y23 = self.tot2(y22) |
| 42 | + return y13_categorical, y23 |
| 43 | + |
| 44 | +model = None |
| 45 | + |
| 46 | +def init (): |
| 47 | + global model |
| 48 | + |
| 49 | + # For this test, hard code nlevs, as well as pth file name/path |
| 50 | + nlevs = 72 |
| 51 | + current_file_directory = os.path.dirname(os.path.abspath(__file__)) |
| 52 | + model_file = f"{current_file_directory}/cldfrac_net_weights.pth" |
| 53 | + |
| 54 | + model = CldFracNet(nlevs,nlevs) |
| 55 | + model.load_state_dict(torch.load(model_file,map_location=torch.device('cpu'))) |
| 56 | + |
| 57 | +def main (ice_threshold, ice_4out_threshold, |
| 58 | + qi, liq_cld_frac, |
| 59 | + ice_cld_frac, tot_cld_frac, |
| 60 | + ice_cld_frac_4out, tot_cld_frac_4out): |
| 61 | + global model |
| 62 | + |
| 63 | + # Convert numpy inputs to torch arrays |
| 64 | + # Note: our pth model expects float32 arrays, so make sure we get the right dtype |
| 65 | + liq_pt = torch.tensor(liq_cld_frac, dtype=torch.float32) |
| 66 | + qi_pt = torch.tensor(qi, dtype=torch.float32) |
| 67 | + |
| 68 | + # Set model in evaluation mode |
| 69 | + model.eval() |
| 70 | + |
| 71 | + with torch.no_grad(): # Disable gradient for inference |
| 72 | + # Run the emulator |
| 73 | + ice_out, tot_out = model(qi_pt,liq_pt) |
| 74 | + |
| 75 | + # Update inout numpy arrays inplace |
| 76 | + ice_cld_frac[:] = ice_out.cpu().numpy() |
| 77 | + tot_cld_frac[:] = tot_out.cpu().numpy() |
0 commit comments