Skip to content

adding dof manager to params and re-writing params as an equinox module #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/hole_array/hole_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
]

dofManager = DofManager(func_space, 2, ebcs)

print(dofManager)
props = {'elastic modulus': 3. * 10.0 * (1. - 2. * 0.3),
'poisson ratio': 0.3,
'version': 'coupled'}
Expand All @@ -53,11 +53,11 @@ def get_ubcs(p):
V = np.zeros(mesh.coords.shape)
index = (mesh.nodeSets['yplus_nodeset'], 1)
V = V.at[index].set(yLoc)
return dofManager.get_bc_values(V)
return p.dof_manager.get_bc_values(V)


def create_field(Uu, p):
return dofManager.create_field(Uu, get_ubcs(p))
return p.dof_manager.create_field(Uu, get_ubcs(p))


def energy_function(Uu, p):
Expand Down Expand Up @@ -120,7 +120,7 @@ def run():
Uu = dofManager.get_unknown_values(np.zeros(mesh.coords.shape))
disp = 0.0
ivs = mech_funcs.compute_initial_state()
p = Objective.Params(disp, ivs)
p = Objective.Params(disp, ivs, dof_manager=dofManager)
precond_strategy = Objective.PrecondStrategy(assemble_sparse)
objective = Objective.Objective(energy_function, Uu, p, precond_strategy)

Expand Down
98 changes: 76 additions & 22 deletions optimism/Objective.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,89 @@
from optimism.FunctionSpace import DofManager
from optimism.JaxConfig import *
from optimism.SparseCholesky import SparseCholesky
import numpy as onp
from scipy.sparse import diags as sparse_diags
from scipy.sparse import csc_matrix
from typing import Optional
import equinox as eqx
import numpy as onp


# static vs dynamics
# differentiable vs undifferentiable
Params = namedtuple('Params',
['bc_data',
'state_data',
'design_data',
'app_data',
'time',
'dynamic_data'],
defaults=(None,None,None,None,None,None))
# TODO fix some of these type hints for better clarity.
# maybe this will help formalize what's what when
class Params(eqx.Module):
bc_data: any
state_data: any
design_data: any
app_data: any
time: any
dynamic_data: any
# Need the eqx.field(static=True) since DofManager
# is composed of mainly og numpy arrays which leads
# to the error
#jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[x,x])
dof_manager: DofManager = eqx.field(static=True)

def __init__(
self,
bc_data = None,
state_data = None,
design_data = None,
app_data = None,
time = None,
dynamic_data = None,
dof_manager: Optional[DofManager] = None
):
self.bc_data = bc_data
self.state_data = state_data
self.design_data = design_data
self.app_data = app_data
self.time = time
self.dynamic_data = dynamic_data
self.dof_manager = dof_manager

def __getitem__(self, index):
if index == 0:
return self.bc_data
elif index == 1:
return self.state_data
elif index == 2:
return self.design_data

Check warning on line 52 in optimism/Objective.py

View check run for this annotation

Codecov / codecov/patch

optimism/Objective.py#L52

Added line #L52 was not covered by tests
elif index == 3:
return self.app_data

Check warning on line 54 in optimism/Objective.py

View check run for this annotation

Codecov / codecov/patch

optimism/Objective.py#L54

Added line #L54 was not covered by tests
elif index == 4:
return self.time
elif index == 5:
return self.dynamic_data
elif index == 6:
return self.dof_manager

Check warning on line 60 in optimism/Objective.py

View check run for this annotation

Codecov / codecov/patch

optimism/Objective.py#L57-L60

Added lines #L57 - L60 were not covered by tests
else:
raise ValueError(f'Bad index value {index}')

Check warning on line 62 in optimism/Objective.py

View check run for this annotation

Codecov / codecov/patch

optimism/Objective.py#L62

Added line #L62 was not covered by tests


# written for backwards compatability
# we can just use the eqx.tree_at syntax in simulations
# or we could write a single method bound to Params for this...
def param_index_update(p, index, newParam):
if index==0:
return Params(newParam, p[1], p[2], p[3], p[4], p[5])
if index==1:
return Params(p[0], newParam, p[2], p[3], p[4], p[5])
if index==2:
return Params(p[0], p[1], newParam, p[3], p[4], p[5])
if index==3:
return Params(p[0], p[1], p[2], newParam, p[4], p[5])
if index==4:
return Params(p[0], p[1], p[2], p[3], newParam, p[5])
if index==5:
return Params(p[0], p[1], p[2], p[3], p[4], newParam)
print('invalid index passed to param_index_update = ', index)
if index == 0:
p = eqx.tree_at(lambda x: x.bc_data, p, newParam)
elif index == 1:
p = eqx.tree_at(lambda x: x.state_data, p, newParam)
elif index == 2:
p = eqx.tree_at(lambda x: x.design_data, p, newParam)
elif index == 3:
p = eqx.tree_at(lambda x: x.app_data, p, newParam)
elif index == 4:
p = eqx.tree_at(lambda x: x.time, p, newParam)
elif index == 5:
p = eqx.tree_at(lambda x: x.dynamic_data, p, newParam)
elif index == 6:
p = eqx.tree_at(lambda x: x.dof_manager, p, newParam)

Check warning on line 82 in optimism/Objective.py

View check run for this annotation

Codecov / codecov/patch

optimism/Objective.py#L81-L82

Added lines #L81 - L82 were not covered by tests
else:
raise ValueError(f'Bad index value {index}')

Check warning on line 84 in optimism/Objective.py

View check run for this annotation

Codecov / codecov/patch

optimism/Objective.py#L84

Added line #L84 was not covered by tests

return p


class PrecondStrategy:
Expand Down