Skip to content

Ralberd/third medium sandbox #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b814c04
adds material_qoi output for multi blocks
ralberd Jan 15, 2025
09e1ab7
updates ReadExodusMesh to work if there is no element map
ralberd Jan 15, 2025
187e037
adds material_qoi to NeoHookean material; stores det(F)
ralberd Jan 15, 2025
94ac054
adds ThirdMediumNeoHookean material for testing third medium formulat…
ralberd Jan 15, 2025
197fcad
adds test to try out different energy formulations for third medium
ralberd Jan 15, 2025
ef7706b
Merge branch 'main' into ralberd/third_medium_sandbox
ralberd Jan 15, 2025
dacc64e
Merge branch 'main' into ralberd/third_medium_sandbox
ralberd Jan 15, 2025
bcb3301
adds quadratic lagrange triangle elements
ralberd Feb 9, 2025
4ff5fb1
adds Lagrange elements option to create_higher_order_mesh_from_simple…
ralberd Feb 11, 2025
2013f1c
fix for interpolants test
ralberd Feb 11, 2025
18b1b6d
fixes issue with TestFixture import
ralberd Feb 11, 2025
cf81b55
adds patch tests for quadratic Lagrange elements; removes duplicated …
ralberd Feb 11, 2025
631a0c1
implements shape function second derivatives for quadratic lagrange e…
ralberd Feb 14, 2025
3b13bf0
adds functions for computing field hessians to FunctionSpace
ralberd Feb 14, 2025
b815778
adds hessian argument for direct FunctionSpace construction to constr…
ralberd Feb 14, 2025
5f23c49
concatenates shape gradients and hessians into same container in orde…
ralberd Feb 17, 2025
a2a2a2a
testing Teran invertible Neo Hookean in third medium
ralberd Feb 17, 2025
33e61ce
small fix for change to displacement gradient storage
ralberd Feb 19, 2025
56a8aab
fixes adjoint function space to match changes in FunctionSpace
ralberd Feb 19, 2025
2aea3bf
update Neohookean to work with changes to displacement gradient storage
ralberd Feb 19, 2025
a350a48
adds regularization term to third medium material model
ralberd Feb 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions optimism/FunctionSpace.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,40 @@ def construct_function_space_from_parent_element(mesh, shapeOnRef, quadratureRul
isAxisymmetric = True
vols = jax.vmap(el_vols, (None, 0, None, 0, None))(mesh.coords, mesh.conns, mesh.parentElement, shapes, quadratureRule.wgauss)

return FunctionSpace(shapes, vols, shapeGrads, mesh, quadratureRule, isAxisymmetric)
if mesh.parentElement.elementType == Interpolants.LAGRANGE_TRIANGLE_ELEMENT:
shapeOnRefHessians = Interpolants.shape2d_lagrange_second_derivatives(mesh.parentElement.degree, mesh.parentElement.coordinates, quadratureRule.xigauss)
shapeHessians = jax.vmap(map_element_shape_hessians, (None, 0, None, None, None))(mesh.coords, mesh.conns, mesh.parentElement, shapeOnRef.gradients, shapeOnRefHessians)
return FunctionSpace(shapes, vols, np.concatenate((shapeGrads, shapeHessians), axis=-1), mesh, quadratureRule, isAxisymmetric)
else:
return FunctionSpace(shapes, vols, shapeGrads, mesh, quadratureRule, isAxisymmetric)


def map_element_shape_grads(coordField, nodeOrdinals, parentElement, shapeGradients):
Xn = coordField.take(nodeOrdinals,0)
v = Xn[parentElement.vertexNodes]
J = np.column_stack((v[0] - v[2], v[1] - v[2]))
J = np.column_stack((v[0] - v[2], v[1] - v[2])) # assumes simplex element
return jax.vmap(lambda dN: solve(J.T, dN.T).T)(shapeGradients)


def map_element_shape_hessians(coordField, nodeOrdinals, parentElement, shapeGradients, shapeHessians):
Xn = coordField.take(nodeOrdinals,0)

dNdX = map_element_shape_grads(coordField, nodeOrdinals, parentElement, shapeGradients)

def quad_point_shape_hessian(shapeGrads, dNdX, shapeHessians):
dXdXi = np.tensordot(Xn, shapeGrads, axes=[0,0])
J = np.array([[dXdXi[0,0]**2, dXdXi[1,0]**2, 2.0*dXdXi[1,0]*dXdXi[0,0]],
[dXdXi[0,1]**2, dXdXi[1,1]**2, 2.0*dXdXi[1,1]*dXdXi[0,1]],
[dXdXi[0,0]*dXdXi[0,1], dXdXi[1,0]*dXdXi[1,1], dXdXi[1,0]*dXdXi[0,1] + dXdXi[0,0]*dXdXi[1,1]]])

d2X_dXi2 = np.tensordot(Xn, shapeHessians, axes=[0,0])
b = np.array([shapeHessians[:,0].T - dNdX[:,0].T * d2X_dXi2[0,0] - dNdX[:,1].T * d2X_dXi2[1,0],
shapeHessians[:,1].T - dNdX[:,0].T * d2X_dXi2[0,1] - dNdX[:,1].T * d2X_dXi2[1,1],
shapeHessians[:,2].T - dNdX[:,0].T * d2X_dXi2[0,2] - dNdX[:,1].T * d2X_dXi2[1,2]])
return solve(J, b).T
return jax.vmap(quad_point_shape_hessian, (0, 0, 0))(shapeGradients, dNdX, shapeHessians) # vmap over nqpe


def compute_element_volumes(coordField, nodeOrdinals, parentElement, shapes, weights):
Xn = coordField.take(nodeOrdinals,0)
v = Xn[parentElement.vertexNodes]
Expand Down
206 changes: 192 additions & 14 deletions optimism/Interpolants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from jaxtyping import Array, Float, Int
from scipy import special
from enum import Enum, auto
import numpy as onp
import equinox as eqx
import jax.numpy as np

class InterpolationType(Enum):
LOBATTO = auto()
LAGRANGE = auto()

class ParentElement(eqx.Module):
"""Finite element on reference domain.
Expand Down Expand Up @@ -44,9 +48,9 @@
is the number of nodes in the element (which is equal to the
number of shape functions).
gradients: Values of the parametric gradients of the shape functions.
Shape is ``(nPts, nDim, nNodes)``, where ``nDim`` is the number
Shape is ``(nPts, nNodes, nDim)``, where ``nDim`` is the number
of spatial dimensions. Line elements are an exception, which
have shape ``(nPts, nNdodes)``.
have shape ``(nNodes, nPts)``.
"""
values: Float[Array, "nq nn"]
gradients: Float[Array, "nq nn nd"]
Expand All @@ -59,6 +63,8 @@
LINE_ELEMENT = 0
TRIANGLE_ELEMENT = 1
TRIANGLE_ELEMENT_WITH_BUBBLE = 2
LAGRANGE_LINE_ELEMENT = 3
LAGRANGE_TRIANGLE_ELEMENT = 4


def make_parent_elements(degree):
Expand Down Expand Up @@ -86,6 +92,33 @@
return xn


def make_lagrange_parent_element_1d(degree):
"""Lagrange Interpolation points on the unit interval [0, 1].
Only implemented for second degree
"""
if degree != 2:
raise NotImplementedError

xn = np.array([0.0, 0.5, 1.0])
vertexPoints = np.array([0, 2], dtype=np.int32)
interiorPoints = np.array([1], dtype=np.int32)
return ParentElement(LAGRANGE_LINE_ELEMENT, int(degree), xn, vertexPoints, None, interiorPoints)


def vander1d(x, degree):
x = onp.asarray(x)
A = onp.zeros((x.shape[0], degree + 1))
dA = onp.zeros((x.shape[0], degree + 1))
domain = [0.0, 1.0]
for i in range(degree + 1):
p = onp.polynomial.Legendre.basis(i, domain=domain)
p *= onp.sqrt(2.0*i + 1.0) # keep polynomial orthonormal
A[:, i] = p(x)
dp = p.deriv()
dA[:, i] = dp(x)
return A, dA


def shape1d(degree, nodalPoints, evaluationPoints):
"""Evaluate shape functions and derivatives at points in the master element.

Expand All @@ -108,18 +141,34 @@
return ShapeFunctions(shape, dshape)


def vander1d(x, degree):
x = onp.asarray(x)
A = onp.zeros((x.shape[0], degree + 1))
dA = onp.zeros((x.shape[0], degree + 1))
domain = [0.0, 1.0]
for i in range(degree + 1):
p = onp.polynomial.Legendre.basis(i, domain=domain)
p *= onp.sqrt(2.0*i + 1.0) # keep polynomial orthonormal
A[:, i] = p(x)
dp = p.deriv()
dA[:, i] = dp(x)
return A, dA
def shape1d_lagrange(degree, nodalPoints, evaluationPoints):
"""Evaluate Lagrange shape functions and derivatives at points in the parent element.
Only implemented for second degree

Returns:
Shape function values and shape function derivatives at ``evaluationPoints``,
in a tuple (``shape``, ``dshape``).
shapes: [nNodes, nEvalPoints]
dshapes: [nNodes, nEvalPoints]
"""
if degree != 2:
raise NotImplementedError

denom1 = (nodalPoints[0] - nodalPoints[1]) * (nodalPoints[0] - nodalPoints[2])
denom2 = (nodalPoints[1] - nodalPoints[0]) * (nodalPoints[1] - nodalPoints[2])
denom3 = (nodalPoints[2] - nodalPoints[0]) * (nodalPoints[2] - nodalPoints[1])

shape1 = (evaluationPoints - nodalPoints[1])*(evaluationPoints - nodalPoints[2]) / denom1
shape2 = (evaluationPoints - nodalPoints[0])*(evaluationPoints - nodalPoints[2]) / denom2
shape3 = (evaluationPoints - nodalPoints[0])*(evaluationPoints - nodalPoints[1]) / denom3
shape = np.stack((shape1, shape2, shape3))

dshape1 = (2.0*evaluationPoints - nodalPoints[2] - nodalPoints[1]) / denom1
dshape2 = (2.0*evaluationPoints - nodalPoints[2] - nodalPoints[0]) / denom2
dshape3 = (2.0*evaluationPoints - nodalPoints[1] - nodalPoints[0]) / denom3
dshape = np.stack((dshape1, dshape2, dshape3))

return ShapeFunctions(shape, dshape)


def make_parent_element_2d(degree):
Expand Down Expand Up @@ -169,6 +218,27 @@

return ParentElement(TRIANGLE_ELEMENT, int(degree), points, vertexPoints, facePoints, interiorPoints)

def make_lagrange_parent_element_2d(degree):
"""Lagrange interpolation points on the triangle
Only implemented for second degree triangles.

Convention for numbering:

2
o
| \
4 o o 1
| \
o--o--o
5 3 0
"""
if degree != 2:
raise NotImplementedError

xn = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.5, 0.0], [0.0, 0.5], [0.0, 0.0]])
vertexPoints = np.array([0, 2, 5], dtype=np.int32)
facePoints = np.array([[0, 1, 2], [2, 4, 5], [5, 3, 0]], dtype=np.int32)
return ParentElement(LAGRANGE_TRIANGLE_ELEMENT, int(degree), xn, vertexPoints, facePoints, np.array([], dtype=np.int32))

def pascal_triangle_monomials(degree):
p = []
Expand Down Expand Up @@ -261,13 +331,121 @@
return ShapeFunctions(np.asarray(shapes), np.asarray(dshapes))


def shape2d_lagrange(degree, nodalPoints, evaluationPoints):
"""Evaluate Lagrange shape functions and derivatives at points in the parent element.
Only implemented for second degree

Reference:
T. Hughes. "The Finite Element Method"
Appendix 3.I
"""
if degree != 2:
raise NotImplementedError

numEvalPoints = evaluationPoints.shape[0]
r = evaluationPoints[:,0]
s = evaluationPoints[:,1]
# t = 1.0 - r - s

shape0 = 2.0*r*r - r # r * (2.0 * r - 1.0)
shape1 = 4.0 * r * s # 4.0 * r * s
shape2 = 2.0*s*s - s # s * (2.0 * s - 1.0)
shape3 = 4.0*(r - r*s - r*r) # 4.0 * r * t
shape4 = 4.0*(s - r*s - s*s) # 4.0 * s * t
shape5 = 1.0 - 3.0*(r + s) + 4.0*r*s + 2.0*(r*r + s*s) # t * (2.0 * t - 1.0)
shape = np.stack((shape0, shape1, shape2, shape3, shape4, shape5)).T

dshape0_dr = 4.0*r - 1.0
dshape0_ds = np.zeros(numEvalPoints)
dshape1_dr = 4.0*s
dshape1_ds = 4.0*r
dshape2_dr = np.zeros(numEvalPoints)
dshape2_ds = 4.0*s - 1.0
dshape3_dr = 4.0*(1.0 - s - 2.0*r)
dshape3_ds = -4.0*r
dshape4_dr = -4.0*s
dshape4_ds = 4.0*(1.0 - r - 2.0*s)
dshape5_dr = 4.0*(r + s) - 3.0
dshape5_ds = 4.0*(r + s) - 3.0
dshape_dr = np.stack((dshape0_dr, dshape1_dr, dshape2_dr, dshape3_dr, dshape4_dr, dshape5_dr)).T
dshape_ds = np.stack((dshape0_ds, dshape1_ds, dshape2_ds, dshape3_ds, dshape4_ds, dshape5_ds)).T
dshape = np.stack((dshape_dr, dshape_ds), axis=2)

return ShapeFunctions(shape, dshape)


def shape2d_lagrange_second_derivatives(degree, nodalPoints, evaluationPoints):
"""Evaluate second derivatives of Lagrange shape functions at points in the parent element.
Only implemented for second degree

Shape of returned second derivatives is ``(nPts, nNodes, nTerms)``, where ``nTerms``
is the number of derivative terms (3). These terms are organized as
[ (d^2 N)/(dr^2), (d^2 N)/(ds^2), (d^2 N)/(dr ds) ]
where r and s are the local triangle coordinates.

Following shape2d_lagrange above, the shape functions are

N0 = r * (2.0 * r - 1.0) = 2.0*r*r - r
N1 = 4.0 * r * s = 4.0 * r * s
N2 = s * (2.0 * s - 1.0) = 2.0*s*s - s
N3 = 4.0 * r * t = 4.0*(r - r*s - r*r)
N4 = 4.0 * s * t = 4.0*(s - r*s - s*s)
N5 = t * (2.0 * t - 1.0) = 1.0 - 3.0*(r + s) + 4.0*r*s + 2.0*(r*r + s*s)

Reference:
T. Hughes. "The Finite Element Method"
Appendix 3.I
"""
if degree != 2:
raise NotImplementedError

numEvalPoints = evaluationPoints.shape[0]
r = evaluationPoints[:,0]
s = evaluationPoints[:,1]
# t = 1.0 - r - s

d2shape0_dr2 = 4.0 * np.ones(numEvalPoints)
d2shape0_ds2 = np.zeros(numEvalPoints)
d2shape0_drds = np.zeros(numEvalPoints)

d2shape1_dr2 = np.zeros(numEvalPoints)
d2shape1_ds2 = np.zeros(numEvalPoints)
d2shape1_drds = 4.0 * np.ones(numEvalPoints)

d2shape2_dr2 = np.zeros(numEvalPoints)
d2shape2_ds2 = 4.0 * np.ones(numEvalPoints)
d2shape2_drds = np.zeros(numEvalPoints)

d2shape3_dr2 = -8.0 * np.ones(numEvalPoints)
d2shape3_ds2 = np.zeros(numEvalPoints)
d2shape3_drds = -4.0 * np.ones(numEvalPoints)

d2shape4_dr2 = np.zeros(numEvalPoints)
d2shape4_ds2 = -8.0 * np.ones(numEvalPoints)
d2shape4_drds = -4.0 * np.ones(numEvalPoints)

d2shape5_dr2 = 4.0 * np.ones(numEvalPoints)
d2shape5_ds2 = 4.0 * np.ones(numEvalPoints)
d2shape5_drds = 4.0 * np.ones(numEvalPoints)

d2shape_dr2 = np.stack((d2shape0_dr2, d2shape1_dr2, d2shape2_dr2, d2shape3_dr2, d2shape4_dr2, d2shape5_dr2)).T
d2shape_ds2 = np.stack((d2shape0_ds2, d2shape1_ds2, d2shape2_ds2, d2shape3_ds2, d2shape4_ds2, d2shape5_ds2)).T
d2shape_drds = np.stack((d2shape0_drds, d2shape1_drds, d2shape2_drds, d2shape3_drds, d2shape4_drds, d2shape5_drds)).T

return np.stack((d2shape_dr2, d2shape_ds2, d2shape_drds), axis=2)


def compute_shapes(parentElement, evaluationPoints):
if parentElement.elementType == LINE_ELEMENT:
return shape1d(parentElement.degree, parentElement.coordinates, evaluationPoints)
elif parentElement.elementType == TRIANGLE_ELEMENT:
return shape2d(parentElement.degree, parentElement.coordinates, evaluationPoints)
elif parentElement.elementType == TRIANGLE_ELEMENT_WITH_BUBBLE:
return shape2dBubble(parentElement, evaluationPoints)
elif parentElement.elementType == LAGRANGE_LINE_ELEMENT:
return shape1d_lagrange(parentElement.degree, parentElement.coordinates, evaluationPoints)

Check warning on line 446 in optimism/Interpolants.py

View check run for this annotation

Codecov / codecov/patch

optimism/Interpolants.py#L446

Added line #L446 was not covered by tests
elif parentElement.elementType == LAGRANGE_TRIANGLE_ELEMENT:
return shape2d_lagrange(parentElement.degree, parentElement.coordinates, evaluationPoints)
else:
raise ValueError('Unknown element type.')

Expand Down
37 changes: 32 additions & 5 deletions optimism/Mechanics.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@


def plane_strain_gradient_transformation(elemDispGrads, elemShapes, elemVols, elemNodalDisps, elemNodalCoords):
return vmap(tensor_2D_to_3D)(elemDispGrads)
def map_2D_tensor_to_3D(H):
return np.zeros((3,H.shape[1]+1)).at[ 0:H.shape[0], 0:2 ].set(H[:,0:2]).at[ 0:H.shape[0], 3: ].set(H[:,2:])
return vmap(map_2D_tensor_to_3D)(elemDispGrads)


def volume_average_J_gradient_transformation(elemDispGrads, elemVols, pShapes):
Expand Down Expand Up @@ -252,15 +254,40 @@
elemIds = fs.mesh.blocks[blockKey]
blockEnergyDensities, blockStresses = FunctionSpace.evaluate_on_block(fs, U, stateVariables, dt, output_constitutive, elemIds, modify_element_gradient=modify_element_gradient)
energy_densities = energy_densities.at[elemIds].set(blockEnergyDensities)
stresses = stresses.at[elemIds].set(blockStresses)
stresses = stresses.at[elemIds].set(blockStresses[:,:,0:3,0:3])

Check warning on line 257 in optimism/Mechanics.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mechanics.py#L257

Added line #L257 was not covered by tests
return energy_densities, stresses


def compute_initial_state():
return _compute_initial_state_multi_block(fs, materialModels)


return MechanicsFunctions(compute_strain_energy, jit(compute_updated_internal_variables), jit(compute_element_stiffnesses), jit(compute_output_energy_densities_and_stresses), compute_initial_state, None, None)
def qoi_to_lagrangian_qoi(compute_material_qoi):
def L(U, gradU, Q, X, dt):
return compute_material_qoi(gradU, Q, dt)
return L

Check warning on line 266 in optimism/Mechanics.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mechanics.py#L264-L266

Added lines #L264 - L266 were not covered by tests

def integrated_material_qoi(U, stateVariables, dt=0.0):
material_qoi = np.zeros((Mesh.num_elements(fs.mesh), len(fs.quadratureRule)))
for blockKey in materialModels:
elemIds = fs.mesh.blocks[blockKey]
block_qoi = FunctionSpace.integrate_over_block(fs, U, stateVariables, dt,

Check warning on line 272 in optimism/Mechanics.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mechanics.py#L269-L272

Added lines #L269 - L272 were not covered by tests
qoi_to_lagrangian_qoi(materialModels[blockKey].compute_material_qoi),
elemIds,
modify_element_gradient=modify_element_gradient)
material_qoi = material_qoi.at[elemIds].set(block_qoi)
return material_qoi

Check warning on line 277 in optimism/Mechanics.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mechanics.py#L276-L277

Added lines #L276 - L277 were not covered by tests

def compute_output_material_qoi(U, stateVariables, dt=0.0):
material_qoi = np.zeros((Mesh.num_elements(fs.mesh), len(fs.quadratureRule)))
for blockKey in materialModels:
elemIds = fs.mesh.blocks[blockKey]
block_qoi = FunctionSpace.evaluate_on_block(fs, U, stateVariables, dt,

Check warning on line 283 in optimism/Mechanics.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mechanics.py#L280-L283

Added lines #L280 - L283 were not covered by tests
qoi_to_lagrangian_qoi(materialModels[blockKey].compute_material_qoi),
elemIds,
modify_element_gradient=modify_element_gradient)
material_qoi = material_qoi.at[elemIds].set(block_qoi)
return material_qoi

Check warning on line 288 in optimism/Mechanics.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mechanics.py#L287-L288

Added lines #L287 - L288 were not covered by tests

return MechanicsFunctions(compute_strain_energy, jit(compute_updated_internal_variables), jit(compute_element_stiffnesses), jit(compute_output_energy_densities_and_stresses), compute_initial_state, integrated_material_qoi, jit(compute_output_material_qoi))


######
Expand Down
17 changes: 11 additions & 6 deletions optimism/Mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,20 @@
return edgeConns, edges


def create_higher_order_mesh_from_simplex_mesh(mesh, order, useBubbleElement=False, copyNodeSets=False, createNodeSetsFromSideSets=False):
def create_higher_order_mesh_from_simplex_mesh(mesh, order, interpolationType = Interpolants.InterpolationType.LOBATTO, useBubbleElement=False, copyNodeSets=False, createNodeSetsFromSideSets=False):
if order==1: return mesh

parentElement1d = Interpolants.make_parent_element_1d(order)

if useBubbleElement:
basis = Interpolants.make_parent_element_2d_with_bubble(order)
if interpolationType == Interpolants.InterpolationType.LAGRANGE:
if useBubbleElement:
raise NotImplementedError

Check warning on line 217 in optimism/Mesh.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mesh.py#L217

Added line #L217 was not covered by tests
parentElement1d = Interpolants.make_lagrange_parent_element_1d(order)
basis = Interpolants.make_lagrange_parent_element_2d(order)
else:
basis = Interpolants.make_parent_element_2d(order)
parentElement1d = Interpolants.make_parent_element_1d(order)
if useBubbleElement:
basis = Interpolants.make_parent_element_2d_with_bubble(order)

Check warning on line 223 in optimism/Mesh.py

View check run for this annotation

Codecov / codecov/patch

optimism/Mesh.py#L223

Added line #L223 was not covered by tests
else:
basis = Interpolants.make_parent_element_2d(order)

conns = np.zeros((num_elements(mesh), basis.coordinates.shape[0]), dtype=np.int_)

Expand Down
Loading