|
| 1 | +from jax import jit |
| 2 | +from optimism import EquationSolver |
| 3 | +from optimism import VTKWriter |
| 4 | +from optimism import FunctionSpace |
| 5 | +from optimism import Interpolants |
| 6 | +from optimism import Mechanics |
| 7 | +from optimism import Mesh |
| 8 | +from optimism import Objective |
| 9 | +from optimism import QuadratureRule |
| 10 | +from optimism import ReadExodusMesh |
| 11 | +from optimism import SparseMatrixAssembler |
| 12 | +from optimism.FunctionSpace import DofManager |
| 13 | +from optimism.FunctionSpace import EssentialBC |
| 14 | +from optimism.material import Neohookean |
| 15 | + |
| 16 | +import jax.numpy as np |
| 17 | +from optimism.inverse import AdjointFunctionSpace |
| 18 | +from collections import namedtuple |
| 19 | + |
| 20 | +EnergyFunctions = namedtuple('EnergyFunctions', |
| 21 | + ['energy_function_coords']) |
| 22 | + |
| 23 | +# simulation parameterized on mesh node coordinates |
| 24 | +class CoordinateParameterizedSimulation: |
| 25 | + |
| 26 | + def __init__(self): |
| 27 | + self.writeOutput = True |
| 28 | + |
| 29 | + self.quad_rule = QuadratureRule.create_quadrature_rule_on_triangle(degree=2) |
| 30 | + self.lineQuadRule = QuadratureRule.create_quadrature_rule_1D(degree=2) |
| 31 | + |
| 32 | + self.ebcs = [ |
| 33 | + EssentialBC(nodeSet='bottom_sideset', component=1), |
| 34 | + |
| 35 | + EssentialBC(nodeSet='right_sideset', component=0), |
| 36 | + EssentialBC(nodeSet='left_sideset', component=0), |
| 37 | + ] |
| 38 | + |
| 39 | + shearModulus = 0.855 # MPa |
| 40 | + bulkModulus = 1000*shearModulus # MPa |
| 41 | + youngModulus = 9.0*bulkModulus*shearModulus / (3.0*bulkModulus + shearModulus) |
| 42 | + poissonRatio = (3.0*bulkModulus - 2.0*shearModulus) / 2.0 / (3.0*bulkModulus + shearModulus) |
| 43 | + props = { |
| 44 | + 'elastic modulus': youngModulus, |
| 45 | + 'poisson ratio': poissonRatio, |
| 46 | + 'version': 'coupled' |
| 47 | + } |
| 48 | + self.mat_model = Neohookean.create_material_model_functions(props) |
| 49 | + |
| 50 | + self.eq_settings = EquationSolver.get_settings(max_trust_iters=2500, use_preconditioned_inner_product_for_cg=True) |
| 51 | + |
| 52 | + self.input_mesh = './snap_cell.exo' |
| 53 | + |
| 54 | + self.maxForce = -3.0 |
| 55 | + self.plot_file = 'force_control_response.npz' |
| 56 | + self.stages = 1 |
| 57 | + steps_per_stage = 65 |
| 58 | + self.steps = self.stages * steps_per_stage |
| 59 | + |
| 60 | + def create_field(self, Uu): |
| 61 | + def get_ubcs(): |
| 62 | + V = np.zeros(self.mesh.coords.shape) |
| 63 | + return self.dof_manager.get_bc_values(V) |
| 64 | + |
| 65 | + return self.dof_manager.create_field(Uu, get_ubcs()) |
| 66 | + |
| 67 | + def reload_mesh(self): |
| 68 | + origMesh = ReadExodusMesh.read_exodus_mesh(self.input_mesh) |
| 69 | + self.mesh = Mesh.create_higher_order_mesh_from_simplex_mesh(origMesh, order=2, createNodeSetsFromSideSets=True) |
| 70 | + |
| 71 | + func_space = FunctionSpace.construct_function_space(self.mesh, self.quad_rule) |
| 72 | + self.dof_manager = DofManager(func_space, 2, self.ebcs) |
| 73 | + |
| 74 | + surfaceXCoords = self.mesh.coords[self.mesh.nodeSets['top_sideset']][:,0] |
| 75 | + self.tractionArea = np.max(surfaceXCoords) - np.min(surfaceXCoords) |
| 76 | + |
| 77 | + self.stateNotStored = True |
| 78 | + self.state = [] |
| 79 | + |
| 80 | + def run_simulation(self): |
| 81 | + |
| 82 | + func_space = FunctionSpace.construct_function_space(self.mesh, self.quad_rule) |
| 83 | + mech_funcs = Mechanics.create_mechanics_functions(func_space, mode2D='plane strain', materialModel=self.mat_model) |
| 84 | + |
| 85 | + # methods defined on the fly |
| 86 | + def energy_function(Uu, p): |
| 87 | + U = self.create_field(Uu) |
| 88 | + internal_variables = p.state_data |
| 89 | + strainEnergy = mech_funcs.compute_strain_energy(U, internal_variables) |
| 90 | + |
| 91 | + F = p.bc_data |
| 92 | + def force_function(x, n): |
| 93 | + return np.array([0.0, F/self.tractionArea]) |
| 94 | + |
| 95 | + loadPotential = Mechanics.compute_traction_potential_energy( |
| 96 | + func_space, U, self.lineQuadRule, self.mesh.sideSets['top_sideset'], |
| 97 | + force_function) |
| 98 | + |
| 99 | + return strainEnergy + loadPotential |
| 100 | + |
| 101 | + def assemble_sparse(Uu, p): |
| 102 | + U = self.create_field(Uu) |
| 103 | + internal_variables = p.state_data |
| 104 | + element_stiffnesses = mech_funcs.compute_element_stiffnesses(U, internal_variables) |
| 105 | + return SparseMatrixAssembler.\ |
| 106 | + assemble_sparse_stiffness_matrix(element_stiffnesses, func_space.mesh.conns, self.dof_manager) |
| 107 | + |
| 108 | + def store_force_displacement(Uu, force_val, force, disp): |
| 109 | + U = self.create_field(Uu) |
| 110 | + |
| 111 | + index = (self.mesh.nodeSets['top_sideset'], 1) |
| 112 | + |
| 113 | + force.append( force_val ) |
| 114 | + disp.append(np.mean(U.at[index].get())) |
| 115 | + |
| 116 | + with open(self.plot_file,'wb') as f: |
| 117 | + np.savez(f, force=force, displacement=disp) |
| 118 | + |
| 119 | + def write_vtk_output(Uu, p, step): |
| 120 | + U = self.create_field(Uu) |
| 121 | + plotName = 'output-'+str(step).zfill(3) |
| 122 | + writer = VTKWriter.VTKWriter(self.mesh, baseFileName=plotName) |
| 123 | + |
| 124 | + writer.add_nodal_field(name='displ', nodalData=U, fieldType=VTKWriter.VTKFieldType.VECTORS) |
| 125 | + |
| 126 | + energyDensities = mech_funcs.compute_output_energy_densities_and_stresses(U, p.state_data)[0] |
| 127 | + cellEnergyDensities = FunctionSpace.project_quadrature_field_to_element_field(func_space, energyDensities) |
| 128 | + writer.add_cell_field(name='strain_energy_density', |
| 129 | + cellData=cellEnergyDensities, |
| 130 | + fieldType=VTKWriter.VTKFieldType.SCALARS) |
| 131 | + writer.write() |
| 132 | + |
| 133 | + # problem set up |
| 134 | + Uu = self.dof_manager.get_unknown_values(np.zeros(self.mesh.coords.shape)) |
| 135 | + ivs = mech_funcs.compute_initial_state() |
| 136 | + p = Objective.Params(bc_data=0., state_data=ivs) |
| 137 | + precond_strategy = Objective.PrecondStrategy(assemble_sparse) |
| 138 | + self.objective = Objective.Objective(energy_function, Uu, p, precond_strategy) |
| 139 | + |
| 140 | + # loop over load steps |
| 141 | + force = 0. |
| 142 | + fd_force = [] |
| 143 | + fd_disp = [] |
| 144 | + |
| 145 | + store_force_displacement(Uu, force, fd_force, fd_disp) |
| 146 | + self.state.append((Uu, p)) |
| 147 | + |
| 148 | + steps_per_stage = int(self.steps / self.stages) |
| 149 | + force_inc = self.maxForce / steps_per_stage |
| 150 | + for step in range(1, steps_per_stage+1): |
| 151 | + print('--------------------------------------') |
| 152 | + print('LOAD STEP ', step) |
| 153 | + force += force_inc |
| 154 | + p = Objective.param_index_update(p, 0, force) |
| 155 | + Uu, solverSuccess = EquationSolver.nonlinear_equation_solve(self.objective, Uu, p, self.eq_settings) |
| 156 | + |
| 157 | + store_force_displacement(Uu, force, fd_force, fd_disp) |
| 158 | + self.state.append((Uu, p)) |
| 159 | + |
| 160 | + if self.writeOutput: |
| 161 | + write_vtk_output(Uu, p, step + 1) |
| 162 | + |
| 163 | + self.stateNotStored = False |
| 164 | + |
| 165 | + # energy functions for computing optimization quantities of interest |
| 166 | + def setup_energy_functions(self): |
| 167 | + shapeOnRef = Interpolants.compute_shapes(self.mesh.parentElement, self.quad_rule.xigauss) |
| 168 | + |
| 169 | + def energy_function_all_dofs(U, p, coords): |
| 170 | + adjoint_func_space = AdjointFunctionSpace.construct_function_space_for_adjoint(coords, shapeOnRef, self.mesh, self.quad_rule) |
| 171 | + mech_funcs = Mechanics.create_mechanics_functions(adjoint_func_space, mode2D='plane strain', materialModel=self.mat_model) |
| 172 | + ivs = p.state_data |
| 173 | + return mech_funcs.compute_strain_energy(U, ivs) |
| 174 | + |
| 175 | + def energy_function_coords(Uu, p, coords): |
| 176 | + U = self.create_field(Uu) |
| 177 | + return energy_function_all_dofs(U, p, coords) |
| 178 | + |
| 179 | + return EnergyFunctions(energy_function_coords) |
| 180 | + |
| 181 | + def compute_energy_quantities(self, uSteps, pSteps, coordinates, energy_function_coords): |
| 182 | + index = (self.mesh.nodeSets['top_sideset'], 1) |
| 183 | + |
| 184 | + totalWork = 0.0 |
| 185 | + complementaryWork = 0.0 |
| 186 | + totalWorkStored = [] |
| 187 | + complementaryWorkStored = [] |
| 188 | + strainEnergyStored = [] |
| 189 | + dissipatedEnergyStored = [] |
| 190 | + for step in range(1, self.steps+1): |
| 191 | + Uu = uSteps[step] |
| 192 | + p = pSteps[step] |
| 193 | + U = self.create_field(Uu) |
| 194 | + force = p.bc_data |
| 195 | + disp = np.mean(U.at[index].get()) |
| 196 | + |
| 197 | + Uu_prev = uSteps[step-1] |
| 198 | + p_prev = pSteps[step-1] |
| 199 | + U_prev = self.create_field(Uu_prev) |
| 200 | + force_prev = p_prev.bc_data |
| 201 | + disp_prev = np.mean(U_prev.at[index].get()) |
| 202 | + |
| 203 | + totalWork += 0.5*(force + force_prev)*(disp - disp_prev) |
| 204 | + complementaryWork += 0.5*(force - force_prev)*(disp + disp_prev) |
| 205 | + |
| 206 | + totalWorkStored.append(totalWork) |
| 207 | + complementaryWorkStored.append(complementaryWork) |
| 208 | + strainEnergyStored.append(energy_function_coords(Uu, p, coordinates)) |
| 209 | + dissipatedEnergyStored.append(totalWork - energy_function_coords(Uu, p, coordinates)) |
| 210 | + |
| 211 | + print("\n Quantities of Interest:") |
| 212 | + print(f"total work: {totalWork}") |
| 213 | + print(f"complementary work: {complementaryWork}") |
| 214 | + print(f"strain energy: {energy_function_coords(Uu, p, coordinates)}") |
| 215 | + print(f"dissipated energy: {totalWork - energy_function_coords(Uu, p, coordinates)}") |
| 216 | + |
| 217 | + with open("energy_histories.npz",'wb') as f: |
| 218 | + np.savez(f, totalWork=totalWorkStored, complementaryWork=complementaryWorkStored, strainEnergy=strainEnergyStored, dissipatedEnergy=dissipatedEnergyStored) |
| 219 | + |
| 220 | + def compute_qois(self): |
| 221 | + if self.stateNotStored: |
| 222 | + self.run_simulation() |
| 223 | + |
| 224 | + parameters = self.mesh.coords |
| 225 | + energyFuncs = self.setup_energy_functions() |
| 226 | + |
| 227 | + uSteps = np.stack([self.state[i][0] for i in range(0, self.steps+1)], axis=0) |
| 228 | + pSteps = [self.state[i][1] for i in range(0, self.steps+1)] |
| 229 | + |
| 230 | + self.compute_energy_quantities(uSteps, pSteps, parameters, jit(energyFuncs.energy_function_coords)) |
| 231 | + |
| 232 | + |
| 233 | + |
| 234 | +if __name__ == '__main__': |
| 235 | + sim = CoordinateParameterizedSimulation() |
| 236 | + sim.reload_mesh() |
| 237 | + sim.compute_qois() |
0 commit comments