Implements MPCs #108

Draft: wants to merge 8 commits into main
103 changes: 100 additions & 3 deletions optimism/FunctionSpace.py
@@ -3,11 +3,12 @@
from optimism import Interpolants
from optimism import Mesh
from optimism import QuadratureRule
from typing import Tuple
from typing import Tuple, List
import equinox as eqx
import jax
import jax.numpy as np
import numpy as onp
import scipy.sparse as sp


class EssentialBC(eqx.Module):
@@ -362,7 +363,6 @@ def integrate_function_on_edges(functionSpace, func, U, quadRule, edges):
integrate_on_edges = jax.vmap(integrate_function_on_edge, (None, None, None, None, 0))
return np.sum(integrate_on_edges(functionSpace, func, U, quadRule, edges))


class DofManager(eqx.Module):
# TODO get type hints below correct
# TODO this one could be moved to jax types if we move towards
@@ -387,7 +387,7 @@ def __init__(self, functionSpace, dim, EssentialBCs):
self.isUnknown = ~self.isBc

self.ids = onp.arange(self.isBc.size).reshape(self.fieldShape)

print(self.ids.shape)
self.unknownIndices = self.ids[self.isUnknown]
self.bcIndices = self.ids[self.isBc]

@@ -466,3 +466,100 @@ def _make_hessian_bc_mask(self, conns):
hessian_bc_mask[e,eFlag,:] = False
hessian_bc_mask[e,:,eFlag] = False
return hessian_bc_mask

# DofManager for multi-point constrained (MPC) problems
class DofManagerMPC(DofManager):
dim: int = eqx.field()
master_slave_pairs: dict = eqx.field()
C: np.ndarray = eqx.field()
C_reduced: np.ndarray = eqx.field()
s_reduced: np.ndarray = eqx.field()
s_tilde: np.ndarray = eqx.field()
T: np.ndarray = eqx.field()
isIndependent: np.ndarray = eqx.field()
isUnconstrained: np.ndarray = eqx.field()
is_slave: np.ndarray = eqx.field()
is_indep_dof: np.ndarray = eqx.field()
unconstrainedIndices: np.ndarray = eqx.field()
slaveIndices: np.ndarray = eqx.field()
bcIndices: np.ndarray = eqx.field()
dofToUnknown: np.ndarray = eqx.field()

def __init__(self, functionSpace, dim, EssentialBCs, master_slave_pairs, C, s):
super().__init__(functionSpace, dim, EssentialBCs)
self.fieldShape = Mesh.num_nodes(functionSpace.mesh), dim
self.dim = dim
self.master_slave_pairs = master_slave_pairs
self.C = np.array(C)
self.s_tilde = np.array(s)
self.C_reduced = np.array(C) # Reduced constraint matrix
self.s_reduced = np.array(s) # Reduced shift vector
self.isIndependent = None
self.isUnconstrained = None
self.is_indep_dof = None
self.is_slave = None
self.slaveIndices = None
self.unconstrainedIndices = None
self.create_mpc_transformation()

def create_mpc_transformation(self):
slave_nodes = np.array(list(self.master_slave_pairs.keys()))

# Boolean masks over the (nNodes, dim) field: slave DOFs vs. independent DOFs
self.is_slave = onp.full(self.fieldShape, False, dtype=bool)
self.is_slave[slave_nodes, :] = True
self.isUnconstrained = ~self.is_slave

self.ids = onp.arange(self.is_slave.size).reshape(self.fieldShape)
self.unconstrainedIndices = self.ids[self.isUnconstrained]
self.slaveIndices = self.ids[self.is_slave]

T = np.zeros((self.is_slave.size, self.unconstrainedIndices.size))

for local_idx, global_idx in enumerate(self.unconstrainedIndices):
T = T.at[global_idx, local_idx].set(1.0)

# Build s_tilde: global shift
s_tilde = np.zeros(self.is_slave.size)

for i, (slave_node, master_node) in enumerate(self.master_slave_pairs.items()):
for d in range(self.dim):
slave_dof = slave_node * self.dim + d
master_dof = master_node * self.dim + d


# Find column index in reduced DOFs
reduced_master_index = np.where(self.unconstrainedIndices == master_dof)[0]
if reduced_master_index.size == 0:
raise ValueError(f"Master DOF {master_dof} is not independent")

T = T.at[slave_dof, reduced_master_index[0]].set(1.0)
s_tilde = s_tilde.at[slave_dof].set(self.s_reduced[i * self.dim + d])

self.T = T
self.s_tilde = s_tilde

# Track dof-to-unknown mapping
dofToUnknown = onp.full(self.is_slave.size, -1, dtype=int)
dofToUnknown[self.unconstrainedIndices] = onp.arange(self.unconstrainedIndices.size)
self.dofToUnknown = dofToUnknown

def create_field(self, Uu, Ubc=0.0):
"""Reconstruct the full field U = T * Uu + s_tilde (Ubc is kept for interface compatibility)."""
U_flat = np.matmul(self.T, Uu) + self.s_tilde
return U_flat.reshape(self.fieldShape)

def get_unknown_values(self, U):
"""Extract the unconstrained (independent) DOF values from a full field U."""
return U[self.isUnconstrained]

def get_slave_values(self, U):
return U[self.is_slave]
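
For reference, here is a standalone sketch of the T / s_tilde construction that DofManagerMPC performs, on a toy two-DOF-per-node problem (the node numbering, pairing, and offset values below are made up for illustration and are not taken from this PR):

import numpy as onp

dim = 2
n_nodes = 4
master_slave_pairs = {3: 1}            # hypothetical: node 3 is a slave of node 1
s = onp.array([0.5, 0.0])              # hypothetical prescribed offsets for the slave DOFs

n_dofs = n_nodes * dim
is_slave = onp.zeros(n_dofs, dtype=bool)
for slave in master_slave_pairs:
    is_slave[slave * dim:(slave + 1) * dim] = True
unconstrained = onp.flatnonzero(~is_slave)

# T maps the reduced (independent) DOFs into the full DOF vector; s_tilde carries the shifts
T = onp.zeros((n_dofs, unconstrained.size))
T[unconstrained, onp.arange(unconstrained.size)] = 1.0
s_tilde = onp.zeros(n_dofs)
for i, (slave, master) in enumerate(master_slave_pairs.items()):
    for d in range(dim):
        slave_dof, master_dof = slave * dim + d, master * dim + d
        col = onp.flatnonzero(unconstrained == master_dof)[0]
        T[slave_dof, col] = 1.0
        s_tilde[slave_dof] = s[i * dim + d]

u_reduced = onp.arange(unconstrained.size, dtype=float)
u_full = T @ u_reduced + s_tilde
# each slave DOF equals its master DOF plus the prescribed shift
assert onp.allclose(u_full[3 * dim:4 * dim], u_full[1 * dim:2 * dim] + s)
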
80 changes: 78 additions & 2 deletions optimism/Objective.py
@@ -1,8 +1,12 @@
from optimism.JaxConfig import *
from optimism.SparseCholesky import SparseCholesky
import numpy as onp
from scipy.sparse import diags as sparse_diags
import jax.numpy as np
import jax
from jax import jit, grad, jacfwd, jvp, vjp
from scipy.sparse import csc_matrix
from scipy.sparse import diags as sparse_diags


# static vs dynamics
# differentiable vs undifferentiable
@@ -265,4 +269,76 @@ def get_value(self, x):
def get_residual(self, x):
return self.gradient(self.scaling * x)


# Objective class for handling Multi-Point Constraints
class ObjectiveMPC(Objective):
def __init__(self, objective_func, x_full, p, dofManagerMPC, precondStrategy=None):
"""
ObjectiveMPC: wraps the original energy functional to solve only for reduced DOFs.
- Applies u = T * ũ + s̃ for MPC condensation.
- Automatically condenses gradient and Hessian.
"""
self.dofManagerMPC = dofManagerMPC
self.T = dofManagerMPC.T
self.s_tilde = dofManagerMPC.s_tilde
self.p = p
self.precondStrategy = precondStrategy

# Store full functional and derivatives
self.full_objective_func = jit(objective_func)
self.full_grad_x = jit(grad(objective_func, 0))
self.full_hess = jit(jacfwd(self.full_grad_x, 0))

self.scaling = 1.0
self.invScaling = 1.0

# Reduced initial guess from full vector
self.x_reduced0 = self.reduce_to_independent_dofs(x_full)

# Define reduced-space objective and gradient
self.objective = jit(lambda x_red, p: self.full_objective_func(self.expand_to_full_dofs(x_red), p))
self.grad_x = jit(lambda x_red, p: self.reduce_gradient(self.full_grad_x(self.expand_to_full_dofs(x_red), p)))

self.precond = SparseCholesky()

def reduce_to_independent_dofs(self, x_full):
"""Extract reduced DOFs from full vector."""
return self.dofManagerMPC.get_unknown_values(x_full)

def expand_to_full_dofs(self, x_reduced):
"""Reconstruct full DOF vector using u = T ũ + s̃."""
return self.dofManagerMPC.create_field(x_reduced)

def reduce_gradient(self, grad_full):
"""Condense the full gradient via the chain rule: grad_reduced = Tᵀ grad_full."""
return np.matmul(self.T.T, grad_full.reshape(-1))

def value(self, x_reduced):
"""Return energy functional in reduced space."""
return self.objective(x_reduced, self.p)

def gradient(self, x_reduced):
"""Return reduced gradient."""
return self.grad_x(x_reduced, self.p)

def hessian(self, x_reduced):
"""Compute reduced Hessian H̃ = Tᵀ H T."""
x_full = self.expand_to_full_dofs(self.scaling * x_reduced)
n_full = self.T.shape[0]
H_full = self.full_hess(x_full, self.p).reshape(n_full, n_full)
return np.matmul(self.T.T, np.matmul(H_full, self.T))

def update_precond(self, x_reduced):
"""Update preconditioner from reduced Hessian."""
print("Updating with condensed Hessian preconditioner.")
H_reduced = csc_matrix(self.hessian(x_reduced))
self.precond.update(lambda attempt: H_reduced)

def apply_precond(self, vx):
return self.precond.apply(vx) if self.precond else vx

def multiply_by_approx_hessian(self, vx):
return self.precond.multiply_by_approximate(vx) if self.precond else vx

def check_stability(self, x_reduced):
if self.precond:
self.precond.check_stability(x_reduced, self.p)
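
To make the condensation explicit, here is a minimal, self-contained JAX sketch (the quadratic energy and the K, f, T, and s̃ values are made-up toy data, not the optimism API): differentiating the reduced energy E(T ũ + s̃) reproduces the condensed gradient Tᵀ ∇E and Hessian Tᵀ H T used above.

import jax
import jax.numpy as np

n_full, n_red = 8, 6
A = jax.random.normal(jax.random.PRNGKey(0), (n_full, n_full))
K = A @ A.T + n_full * np.eye(n_full)          # toy SPD "stiffness"
f = jax.random.normal(jax.random.PRNGKey(1), (n_full,))

# toy transformation: first 6 DOFs independent, DOFs 6/7 tied to DOFs 2/3 with a shift
T = np.zeros((n_full, n_red)).at[:n_red].set(np.eye(n_red)).at[6, 2].set(1.0).at[7, 3].set(1.0)
s_tilde = np.zeros(n_full).at[6].set(0.5)

energy = lambda u: 0.5 * u @ K @ u - f @ u     # full-space quadratic energy
reduced_energy = lambda u_red: energy(T @ u_red + s_tilde)

u_red = jax.random.normal(jax.random.PRNGKey(2), (n_red,))
u_full = T @ u_red + s_tilde

# chain rule: grad_red = Tᵀ grad_full and H_red = Tᵀ H_full T
assert np.allclose(jax.grad(reduced_energy)(u_red), T.T @ (K @ u_full - f), atol=1e-4)
assert np.allclose(jax.hessian(reduced_energy)(u_red), T.T @ K @ T, atol=1e-4)

This is the same algebra that reduce_gradient and hessian apply after expanding the reduced vector to the full field.
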
88 changes: 85 additions & 3 deletions optimism/test/MeshFixture.py
@@ -1,9 +1,15 @@
# fea data structures
# ------------------------------------------------------------------
# Note: this fixture adds a function that creates node-set layers
# ------------------------------------------------------------------

import jax.numpy as jnp
import numpy as onp
from optimism.Mesh import *
from optimism import Surface

# testing utils
from .TestFixture import *
from TestFixture import *

d_kappa = 1.0
d_nu = 0.3
@@ -139,4 +145,80 @@ def is_edge_on_left(xyOnEdge):
blocks = {'block': np.arange(conns.shape[0])}
U = np.zeros(coords.shape)
return construct_mesh_from_basic_data(coords, conns, None, nodeSets, sideSets), U


def create_mesh_disp_and_nodeset_layers(self, Nx, Ny, xRange, yRange, initial_disp_func, setNamePostFix=''):
coords, conns = create_structured_mesh_data(Nx, Ny, xRange, yRange)
tol = 1e-8
nodeSets = {}

# Predefined boundary node sets
nodeSets['left'+setNamePostFix] = jnp.flatnonzero(coords[:, 0] < xRange[0] + tol)
nodeSets['bottom'+setNamePostFix] = jnp.flatnonzero(coords[:, 1] < yRange[0] + tol)
nodeSets['right'+setNamePostFix] = jnp.flatnonzero(coords[:, 0] > xRange[1] - tol)
nodeSets['top'+setNamePostFix] = jnp.flatnonzero(coords[:, 1] > yRange[1] - tol)
nodeSets['all_boundary'+setNamePostFix] = jnp.flatnonzero(
(coords[:, 0] < xRange[0] + tol) |
(coords[:, 1] < yRange[0] + tol) |
(coords[:, 0] > xRange[1] - tol) |
(coords[:, 1] > yRange[1] - tol)
)

# Identify unique y-layers for nodes
unique_y_layers = sorted(onp.unique(coords[:, 1]))
# print("Unique y-layers identified:", unique_y_layers)

# Ensure we have the expected number of layers
assert len(unique_y_layers) == Ny, f"Expected {Ny} y-layers, but found {len(unique_y_layers)}."

# Initialize an empty list to store rows of nodes
y_layer_rows = []

# Map nodes to y_layers and construct rows
for i, y_val in enumerate(unique_y_layers):
nodes_in_layer = onp.flatnonzero(onp.abs(coords[:, 1] - y_val) < tol)
y_layer_rows.append(nodes_in_layer)
# print(f"Nodes in y-layer {i+1} (y = {y_val}):", nodes_in_layer)

# Convert list of rows into a structured 2D JAX array, padding with -1
max_nodes_per_layer = max(len(row) for row in y_layer_rows)
y_layers = -1 * jnp.ones((len(y_layer_rows), max_nodes_per_layer), dtype=int) # Initialize with -1

for i, row in enumerate(y_layer_rows):
y_layers = y_layers.at[i, :len(row)].set(row) # Fill each row with nodes from the layer

# Print for debugging
# print("y_layers (2D array):", y_layers)

def is_edge_on_left(xyOnEdge):
return np.all( xyOnEdge[:,0] < xRange[0] + tol )

def is_edge_on_bottom(xyOnEdge):
return np.all( xyOnEdge[:,1] < yRange[0] + tol )

def is_edge_on_right(xyOnEdge):
return np.all( xyOnEdge[:,0] > xRange[1] - tol )

def is_edge_on_top(xyOnEdge):
return np.all( xyOnEdge[:,1] > yRange[1] - tol )

sideSets = {}
sideSets['left'+setNamePostFix] = Surface.create_edges(coords, conns, is_edge_on_left)
sideSets['bottom'+setNamePostFix] = Surface.create_edges(coords, conns, is_edge_on_bottom)
sideSets['right'+setNamePostFix] = Surface.create_edges(coords, conns, is_edge_on_right)
sideSets['top'+setNamePostFix] = Surface.create_edges(coords, conns, is_edge_on_top)

allBoundaryEdges = np.vstack([s for s in sideSets.values()])
sideSets['all_boundary'+setNamePostFix] = allBoundaryEdges

blocks = {'block'+setNamePostFix: np.arange(conns.shape[0])}
mesh = construct_mesh_from_basic_data(coords, conns, blocks, nodeSets, sideSets)

return mesh, vmap(initial_disp_func)(mesh.coords), y_layers
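
As a usage sketch (the pair_top_to_bottom helper is hypothetical, not part of this PR): the padded y_layers array returned above can be turned into the master_slave_pairs dict expected by DofManagerMPC, for example by tying each node of the top layer to the node at the same position in the bottom layer.

import numpy as onp

def pair_top_to_bottom(y_layers):
    # rows of y_layers hold the node ids of one y-layer, padded with -1; skip the padding
    bottom = onp.asarray(y_layers[0])
    top = onp.asarray(y_layers[-1])
    return {int(s): int(m) for s, m in zip(top, bottom) if s >= 0 and m >= 0}

# mesh, U, y_layers = self.create_mesh_disp_and_nodeset_layers(Nx, Ny, xRange, yRange, disp_func)
# master_slave_pairs = pair_top_to_bottom(y_layers)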