13 changes: 11 additions & 2 deletions examples/burgers1d.yml
@@ -1,12 +1,12 @@
lasdi:
type: gplasdi
gplasdi:
# device: mps
device: cuda
n_samples: 20
lr: 0.001
max_iter: 28000
n_iter: 2000
max_greedy_iter: 28000
ld_weight: 0.1
coef_weight: 1.e-6
path_checkpoint: checkpoint
@@ -67,12 +67,21 @@ latent_space:
ae:
hidden_units: [100]
latent_dimension: 5
activation: softplus

latent_dynamics:
type: sindy
sindy:
fd_type: sbp12
coef_norm_order: fro
higher_order_terms: 1
extra_functions: []
# type: edmd
# edmd:
#   fd_type: sbp12
#   coef_norm_order: fro
#   higher_order_terms: 0
#   extra_functions: []

physics:
type: burgers1d
172 changes: 172 additions & 0 deletions src/lasdi/latent_dynamics/edmd.py
@@ -0,0 +1,172 @@
import numpy as np
import torch
from scipy.integrate import odeint
from . import LatentDynamics
from ..inputs import InputParser
from ..fd import FDdict
from scipy.linalg import logm
import importlib

def get_function_from_string(func_str):
# Split the string into module and function
module_name, func_name = func_str.rsplit('.', 1)
module = importlib.import_module(module_name)
return getattr(module, func_name)
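
A quick usage sketch (torch.sin here is only an illustrative entry, not something this PR requires):

    fn = get_function_from_string('torch.sin')   # imports torch and returns torch.sin
    fn(torch.zeros(3))                           # tensor([0., 0., 0.])

so the strings listed under extra_functions are expected to be fully qualified names of callables that accept a torch tensor.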

class EDMD(LatentDynamics):
fd_type = ''
fd = None
fd_oper = None

def __init__(self, dim, high_order_terms, rand_functions, nt, config):
Collaborator:
Based on the usage shown below, high_order_terms and rand_functions can be parsed from config and need not be passed as input arguments. Can we parse them from config and reduce the number of input arguments?

Collaborator Author:
They could be. But updating the workflow maintains the original structure of defining trainers/latent space/latent physics. Can update it your way though if necessary!
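
A rough sketch of the suggested alternative (hedged; it reuses the InputParser calls already used below and assumes getInput can return the list-valued extra_functions entry):

    # Hypothetical variant: the extra settings are parsed from config, so the
    # signature matches the other latent-dynamics classes.
    def __init__(self, dim, nt, config):
        super().__init__(dim, nt)

        assert('edmd' in config)
        parser = InputParser(config['edmd'], name='edmd_input')

        # Parsed from the yml block rather than passed in by workflow.py.
        self.high_order_terms = parser.getInput(['higher_order_terms'], fallback=0)
        self.rand_functions = parser.getInput(['extra_functions'], fallback=[])

        self.ncoefs = ((len(self.rand_functions) + self.high_order_terms)*self.dim + self.dim) ** 2

The remaining fd_type, coef_norm_order, and MSE setup would stay as in the committed version below.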

super().__init__(dim, nt)

# Defining the higher order terms
self.high_order_terms = high_order_terms
self.rand_functions = rand_functions

# Number of coefficients depends on the basis functions used
self.ncoefs = ((len(self.rand_functions) + self.high_order_terms)*self.dim + self.dim) ** 2
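# e.g. with dim = 5, one higher-order term, and no extra functions: ncoefs = ((0 + 1)*5 + 5)**2 = 100.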

assert('edmd' in config)
parser = InputParser(config['edmd'], name='edmd_input')

'''
fd_type is the string that specifies finite-difference scheme for time derivative:
- 'sbp12': summation-by-parts 1st/2nd (boundary/interior) order operator
- 'sbp24': summation-by-parts 2nd/4th order operator
- 'sbp36': summation-by-parts 3rd/6th order operator
- 'sbp48': summation-by-parts 4th/8th order operator
'''
self.fd_type = parser.getInput(['fd_type'], fallback='sbp12')
self.fd = FDdict[self.fd_type]
self.fd_oper, _, _ = self.fd.getOperators(self.nt)

# NOTE(kevin): by default, this will be L1 norm.
self.coef_norm_order = parser.getInput(['coef_norm_order'], fallback=1)

# TODO(kevin): other loss functions
self.MSE = torch.nn.MSELoss()

return

def calibrate(self, Z, dt, compute_loss=True, numpy=False):
''' loop over all train cases, if Z dimension is 3 '''
if (Z.dim() == 3):
n_train = Z.size(0)

if (numpy):
coefs = np.zeros([n_train, self.ncoefs])
else:
coefs = torch.zeros([n_train, self.ncoefs])
loss_edmd, loss_coef = 0.0, 0.0

for i in range(n_train):
result = self.calibrate(Z[i], dt, compute_loss, numpy)
if (compute_loss):
coefs[i] = result[0]
loss_edmd += result[1]
loss_coef += result[2]
else:
coefs[i] = result

if (compute_loss):
return coefs, loss_edmd, loss_coef
else:
return coefs

''' evaluate for one train case '''
assert(Z.dim() == 2)

# Start from Z; the torch.cat calls below allocate new tensors, so Z itself is not modified.
Z_i = Z

# Adding higher order terms: loop over however many higher order terms were requested.
for i in range(self.high_order_terms):

# Append the higher order expressions to the candidate library
Z_i = torch.cat([Z_i, Z**(i+2)], dim = 1)

# Adding the user-specified extra functions (e.g. trigonometric)
for i in self.rand_functions:
Z_i = torch.cat([Z_i, get_function_from_string(i)(Z)], dim = 1)

# Transpose Z so that columns are snapshots!
Z_i = torch.transpose(Z_i,0,1)

# Get the Z' matrix!
Z_plus = Z_i[:,1:]
Z_minus = Z_i[:,0:-1]

# Get the A operator: using lstsq since it is more stable than the pseudo-inverse!
A = (torch.linalg.lstsq(Z_minus.T,Z_plus.T).solution).T
#A = Z_plus @ torch.linalg.pinv(Z_minus)

# Compute the losses!
if (compute_loss):

# NOTE(khushant): This loss is different from what is used in SINDy.
loss_edmd = self.MSE(Z_plus, A @ Z_minus)

# NOTE(kevin): by default, this will be L1 norm.
loss_coef = torch.norm(A, self.coef_norm_order)

# output of lstsq is not contiguous in memory.
coefs = A.detach().flatten()
if (numpy):
coefs = coefs.numpy()

if (compute_loss):
return coefs, loss_edmd, loss_coef
else:
return coefs

def simulate(self, coefs, z0, t_grid):

'''

Advances the fitted discrete-time EDMD system for a single training case, given the initial condition z0 = encoder(u0).

'''
# copy is inevitable for numpy==1.26. removed copy=False temporarily.
A = coefs.reshape([(self.high_order_terms + len(self.rand_functions)) * self.dim + self.dim, (self.high_order_terms + len(self.rand_functions)) * self.dim + self.dim])

Z_i = np.zeros((len(t_grid), self.dim))
Z_i[0,:] = z0

# Stepping the fitted discrete-time system forward

for i in range(1,len(t_grid)):

# Start the augmented state from the previous latent state (np.hstack below allocates new arrays)
z_new = Z_i[i-1,:]

# Add the higher order terms!
for j in range(self.high_order_terms):

# Get the higher order terms if any
new_terms = np.power(Z_i[i-1,:],j+2)

# Append the higher order terms to the augmented state!
z_new = np.hstack((z_new,new_terms))

# Add the user-specified extra functions (e.g. trigonometric) to the augmented state!
for k in self.rand_functions:

# Evaluate the extra function on the previous latent state!
new_terms = get_function_from_string(k)(torch.from_numpy(Z_i[i-1,:]))

# Append these terms to the augmented state!
z_new = np.hstack((z_new, new_terms.detach().cpu().numpy()))

# Apply the EDMD operator and keep the first self.dim components!
Z_i[i,:] = (A @ z_new)[:self.dim]

return Z_i

def export(self):
param_dict = super().export()
param_dict['fd_type'] = self.fd_type
param_dict['coef_norm_order'] = self.coef_norm_order
return param_dict
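
To make the fit-then-rollout logic above concrete, here is a small self-contained sketch of the same least-squares step on a toy two-dimensional latent trajectory (no higher-order terms or extra functions; everything below is local to the sketch, not part of the module):

    import numpy as np
    import torch

    # Toy latent trajectory: 50 snapshots of a rotation, which is exactly linear in discrete time.
    t = np.linspace(0.0, 5.0, 50)
    Z = torch.tensor(np.stack([np.cos(t), np.sin(t)], axis=1), dtype=torch.float64)

    # Same least-squares step as calibrate(): fit A so that Z_plus ~= A @ Z_minus.
    Z_T = Z.T                                  # columns are snapshots
    Z_plus, Z_minus = Z_T[:, 1:], Z_T[:, :-1]
    A = torch.linalg.lstsq(Z_minus.T, Z_plus.T).solution.T

    # Same rollout as simulate(): iterate the discrete map from the first snapshot.
    Z_pred = torch.zeros_like(Z)
    Z_pred[0] = Z[0]
    for i in range(1, Z.shape[0]):
        Z_pred[i] = A @ Z_pred[i - 1]

    print(torch.max(torch.abs(Z_pred - Z)))    # close to machine precision for this linear toy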

11 changes: 8 additions & 3 deletions src/lasdi/workflow.py
@@ -8,6 +8,7 @@
from .gplasdi import BayesianGLaSDI
from .latent_space import Autoencoder
from .latent_dynamics.sindy import SINDy
from .latent_dynamics.edmd import EDMD
from .physics.burgers1d import Burgers1D
from .param import ParameterSpace
from .inputs import InputParser
@@ -16,7 +17,7 @@

latent_dict = {'ae': Autoencoder}

ld_dict = {'sindy': SINDy}
ld_dict = {'sindy': SINDy, 'edmd': EDMD}

physics_dict = {'burgers1d': Burgers1D}

@@ -151,14 +152,18 @@ def initialize_trainer(config, restart_file=None):

physics = initialize_physics(config, param_space.param_name)
latent_space = initialize_latent_space(physics, config)

if (restart_file is not None):
latent_space.load(restart_file['latent_space'])

# do we need a separate routine for latent dynamics initialization?
ld_type = config['latent_dynamics']['type']
assert(ld_type in config['latent_dynamics'])
assert(ld_type in ld_dict)
latent_dynamics = ld_dict[ld_type](latent_space.n_z, physics.nt, config['latent_dynamics'])

# Update the latent-dynamics constructor call to pass the higher-order terms and other nonlinear functions from the .yml file.

latent_dynamics = ld_dict[ld_type](latent_space.n_z, config['latent_dynamics'][ld_type]['higher_order_terms'], config['latent_dynamics'][ld_type]['extra_functions'], physics.nt, config['latent_dynamics'])
Collaborator:
As I mentioned above, high_order_terms and rand_functions can be parsed from config and need not be passed as input arguments. That way we wouldn't have to change this line.
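
If the constructor parsed those settings itself, this call could presumably stay in its original form:

    latent_dynamics = ld_dict[ld_type](latent_space.n_z, physics.nt, config['latent_dynamics'])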

if (restart_file is not None):
latent_dynamics.load(restart_file['latent_dynamics'])

@@ -329,4 +334,4 @@ def collect_samples(trainer, config):
return result, next_step

if __name__ == "__main__":
main()
main()