Skip to content

Commit 3639425

Browse files
authored
Merge pull request #28 from btalamini/feature/move_jax_settings_to_init
Feature/specify global jax settings in a more standard way
2 parents e9a1db5 + 592550c commit 3639425

File tree

20 files changed

+104
-90
lines changed

20 files changed

+104
-90
lines changed

examples/arch_bc/ArchBc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from matplotlib import pyplot as plt
22

3-
from optimism.JaxConfig import *
43
from optimism import EquationSolver as EqSolver
54
from optimism import FunctionSpace
65
from optimism import Interpolants

examples/arch_bc/ArchTraction.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from matplotlib import pyplot as plt
22

3-
from optimism.JaxConfig import *
43
from optimism import EquationSolver as EqSolver
54
from optimism import EquationSolverSubspace as SolverSubspace
65
from optimism import FunctionSpace

examples/buckle_slide/Sliding.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from optimism.JaxConfig import *
1+
from functools import partial
2+
from jax import jit
3+
from jax import numpy as np
24

35
from optimism import EquationSolver as EqSolver
46
from optimism import FunctionSpace
@@ -20,9 +22,6 @@
2022

2123
from optimism.test.MeshFixture import MeshFixture
2224

23-
import os
24-
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=2"
25-
2625
props = {'elastic modulus': 10.0,
2726
'poisson ratio': 0.3}
2827

examples/contact_mesh_on_mesh/TwoBodyContact.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from matplotlib import pyplot as plt
2+
from jax import numpy as np
23

3-
from optimism.JaxConfig import *
44
from optimism import EquationSolver as EqSolver
5-
65
from optimism import Mesh
76
from optimism import Mechanics
87
from optimism import FunctionSpace

examples/friction/CornerSlide.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from optimism.JaxConfig import *
2-
1+
from functools import partial
2+
import jax
3+
from jax import numpy as np
34

45
from optimism.material import Neohookean
56
from optimism import Mechanics
@@ -112,7 +113,7 @@ def constraint_func(Uu, p):
112113
self.constraint_func = constraint_func
113114

114115

115-
@partial(jit, static_argnums=0)
116+
@partial(jax.jit, static_argnums=0)
116117
def create_field(self, Uu, p):
117118
return self.dofManager.create_field(Uu, self.get_ubcs(p))
118119

@@ -150,7 +151,7 @@ def run(self):
150151
penalty = 2.5
151152
kappa0 = penalty * np.ones(lam0.shape)
152153

153-
hess_func = jit(hessian(lambda Uu,p: self.energy_func(Uu,lam0,p)))
154+
hess_func = jax.jit(jax.hessian(lambda Uu,p: self.energy_func(Uu,lam0,p)))
154155

155156
objective = ConstrainedQuasiObjective(self.energy_func,
156157
self.constraint_func,

examples/hemisphere_cap/hemisphere_plastic_disp_control.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from optimism.JaxConfig import *
1+
import jax
2+
from jax import numpy as np
3+
24
from optimism import EquationSolver as EqSolver
35
from optimism import FunctionSpace
46
from optimism import Interpolants
@@ -55,7 +57,7 @@ def compute_energy_from_bcs(Uu, Ubc, p):
5557
strainEnergy = self.bvpFuncs.compute_strain_energy(U, internalVariables)
5658
return strainEnergy
5759

58-
self.compute_bc_reactions = jit(grad(compute_energy_from_bcs, 1))
60+
self.compute_bc_reactions = jax.jit(jax.grad(compute_energy_from_bcs, 1))
5961

6062
self.trSettings = EqSolver.get_settings(max_trust_iters=400, t1=0.4, t2=1.5, eta1=1e-8, eta2=0.2, eta3=0.8, over_iters=100)
6163

examples/hemisphere_cap/hemisphere_plastic_load_control.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from optimism.JaxConfig import *
1+
import jax
2+
from jax import numpy as np
3+
24
from optimism import EquationSolver as EqSolver
35
from optimism import FunctionSpace
46
from optimism import Interpolants
@@ -31,11 +33,11 @@ def integrate_field_function_on_edge(quadratureRule, edge, mesh, func):
3133
jac = np.linalg.norm(edgeCoords[0,:] - edgeCoords[1,:])
3234
XGauss = edgeCoords[0] + np.outer(xigauss, edgeCoords[1] - edgeCoords[0])
3335
dX = jac*wgauss
34-
return np.dot(vmap(func)(XGauss), dX)
36+
return np.dot(jax.vmap(func)(XGauss), dX)
3537

3638

3739
def integrate_field_function_on_surface(quadratureRule, edges, mesh, func):
38-
F = vmap(integrate_field_function_on_edge, (None,0,None,None))
40+
F = jax.vmap(integrate_field_function_on_edge, (None,0,None,None))
3941
return np.sum(F(quadratureRule, edges, mesh, func))
4042

4143

@@ -94,7 +96,7 @@ def compute_energy_from_bcs(Uu, Ubc, p):
9496
loadPotential = TractionBC.compute_traction_potential_energy(self.mesh, U, self.lineQuadRule, self.mesh.sideSets['push'], lambda X: np.array([0.0, -F/self.pushArea]))
9597
return strainEnergy + loadPotential
9698

97-
self.compute_bc_reactions = jit(grad(compute_energy_from_bcs, 1))
99+
self.compute_bc_reactions = jax.jit(jax.grad(compute_energy_from_bcs, 1))
98100

99101
self.trSettings = EqSolver.get_settings(max_trust_iters=400, t1=0.4, t2=1.5, eta1=1e-8, eta2=0.2, eta3=0.8, over_iters=100)
100102

examples/inverse/BuckleInverse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from optimism.JaxConfig import *
1+
import jax
2+
from jax import numpy as np
23
from jax import custom_jvp, custom_vjp
34

45
from optimism import ReadMesh
@@ -197,7 +198,7 @@ def run(self):
197198
outputDisp = [0,]
198199
outputForce = [0,]
199200

200-
reaction_func = jit( lambda x,p: grad(self.energy_func,1)(x,p)[0][1] )
201+
reaction_func = jax.jit( lambda x,p: jax.grad(self.energy_func,1)(x,p)[0][1] )
201202

202203
N = 1
203204

@@ -254,7 +255,7 @@ def loss(chi):
254255
opt_state = opt_init(chi)
255256

256257
def step(step, opt_state):
257-
value, grads = value_and_grad(loss)(get_params(opt_state))
258+
value, grads = jax.value_and_grad(loss)(get_params(opt_state))
258259
opt_state = opt_update(step, grads, opt_state)
259260
return value, opt_state
260261

examples/inverse/arch/ArchInverse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from optimism.JaxConfig import *
1+
import jax
2+
from jax import numpy as np
23

34
from optimism import ReadMesh
45
from optimism.material import Neohookean
@@ -105,7 +106,7 @@ def energy_func(self, Uu, p):
105106

106107

107108
def reaction_func(self, Uu, p):
108-
return grad(self.energy_func,1)(Uu,p)[0]
109+
return jax.grad(self.energy_func,1)(Uu,p)[0]
109110

110111

111112
def compute_volume(self, p):
@@ -226,7 +227,7 @@ def debug_loss(chi):
226227
opt_state = opt_init(chi)
227228

228229
def step(step, opt_state):
229-
value, grads = value_and_grad(loss)(get_params(opt_state))
230+
value, grads = jax.value_and_grad(loss)(get_params(opt_state))
230231
opt_state = opt_update(step, grads, opt_state)
231232
return value, opt_state
232233

examples/material_test/MaterialTest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Demonstrate the material testing tools."""
22

3+
from jax import numpy as np
34
from matplotlib import pyplot as plt
45

5-
from optimism.JaxConfig import *
66
from optimism.material import MaterialUniaxialSimulator
77
from optimism.material import J2Plastic
88

examples/sharp_notch_fracture/SharpNotchFracture.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from functools import partial
12
import jax
3+
from jax import numpy as np
24

3-
from optimism.JaxConfig import *
45
from optimism import AlSolver
56
from optimism import BoundConstrainedSolver
67
from optimism import BoundConstrainedObjective
@@ -102,7 +103,7 @@ def energy_for_rxns(Ubc, Uu, p):
102103
internalVariables = p[1]
103104
return self.bvpFunctions.compute_internal_energy(U, internalVariables)
104105

105-
self.compute_reactions = jit(grad(energy_for_rxns))
106+
self.compute_reactions = jax.jit(jax.grad(energy_for_rxns))
106107

107108

108109
def plot_solution(self, U, p, lagrange, plotName):

examples/surfing_fracture/SurfingFracture.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from functools import partial
2+
import jax
3+
from jax import numpy as np
14
from matplotlib import pyplot as plt
2-
from scipy.sparse import diags as sparse_diags
35
import numpy as onp
4-
5-
from optimism.JaxConfig import *
6+
from scipy.sparse import diags as sparse_diags
67

78
from optimism import BoundConstrainedObjective
89
from optimism import ConstrainedObjective
@@ -64,7 +65,7 @@ def apply_mode_I_field_at_point(X, K_I):
6465

6566

6667
def apply_mode_I_Bc(boundaryNodeCoords, K_I):
67-
return vmap(apply_mode_I_field_at_point, (0, None))(boundaryNodeCoords, K_I)
68+
return jax.vmap(apply_mode_I_field_at_point, (0, None))(boundaryNodeCoords, K_I)
6869

6970

7071
crackDirection = np.array([1.0,0.0])
@@ -84,7 +85,7 @@ def J_integral(U, internals, mesh, fs, edges, bvpFuncs):
8485
stresses = FunctionSpace.project_quadrature_field_to_element_field(fs, stresses)
8586
Ws = FunctionSpace.project_quadrature_field_to_element_field(fs, Ws)
8687

87-
computeJs = vmap(compute_J_integral_on_edge, (None, 0, 0, 0, 0))
88+
computeJs = jax.vmap(compute_J_integral_on_edge, (None, 0, 0, 0, 0))
8889
return np.sum(computeJs(mesh,
8990
edges,
9091
Ws[edges[:,0]],
@@ -244,8 +245,7 @@ def get_ubcs(self, p):
244245
Xb = self.mesh.coords[self.mesh.nodeSets['external'],:]
245246
Xb = Xb.at[:,0].add(-origin)
246247
modeIBcs = apply_mode_I_Bc(Xb, KI)
247-
index = (self.mesh.nodeSets['external'],:2)
248-
V = V.at[index].set(modeIBcs)
248+
V = V.at[self.mesh.nodeSets['external'],:2].set(modeIBcs)
249249

250250
return self.dofManager.get_bc_values(V)
251251

examples/tension_axisymmetric/TensionAxisymmetric.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from functools import partial
2+
import jax
3+
from jax import numpy as np
14
from matplotlib import pyplot as plt
25

3-
from optimism.JaxConfig import *
46
from optimism import EquationSolver as EqSolver
57
from optimism import FunctionSpace
68
from optimism import Interpolants
@@ -84,8 +86,8 @@ def energy_function(self, Uu, p):
8486
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables)
8587

8688

87-
@partial(jit, static_argnums=0)
88-
@partial(value_and_grad, argnums=2)
89+
@partial(jax.jit, static_argnums=0)
90+
@partial(jax.value_and_grad, argnums=2)
8991
def compute_reactions_from_bcs(self, Uu, Ubc, internalVariables):
9092
U = self.dofManager.create_field(Uu, Ubc)
9193
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables)

examples/tension_axisymmetric/TensionAxisymmetricFracture.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from functools import partial
2+
import jax
3+
from jax import numpy as np
14
from scipy.sparse import diags as sp_sparse_diags
25

3-
from optimism.JaxConfig import *
46
from optimism import AlSolver
57
from optimism import BoundConstrainedSolver
68
from optimism import BoundConstrainedObjective
@@ -39,7 +41,7 @@
3941

4042
def compute_row_sum_gram_matrix(fs, dummyU, dummyInternalVars):
4143
func = lambda u, uGrad, q, x: u
42-
compute = grad(FunctionSpace.integrate_over_block, 1)
44+
compute = jax.grad(FunctionSpace.integrate_over_block, 1)
4345
return compute(fs, dummyU, dummyInternalVars, func, fs.mesh.blocks['block_1'])
4446

4547

@@ -112,8 +114,8 @@ def energy_function(self, Uu, p):
112114
return self.bvpFunctions.compute_internal_energy(U, internalVariables)
113115

114116

115-
@partial(jit, static_argnums=0)
116-
@partial(value_and_grad, argnums=2)
117+
@partial(jax.jit, static_argnums=0)
118+
@partial(jax.value_and_grad, argnums=2)
117119
def compute_reactions_from_bcs(self, Uu, Ubc, internalVariables):
118120
U = self.dofManager.create_field(Uu, Ubc)
119121
return self.bvpFunctions.compute_internal_energy(U, internalVariables)

examples/tension_axisymmetric/materialPointSimulation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from matplotlib import pyplot as plt
22

3-
from optimism.JaxConfig import *
43
from optimism import TensorMath
5-
from optimism.phasefield import SandiaModel as MaterialModel
4+
from optimism.phasefield import PhaseFieldLorentzPlastic as MaterialModel
65
from optimism.phasefield.MaterialPointSimulator import MaterialPointSimulator
76

87
from properties import props

examples/uniaxial/Uniaxial.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from functools import partial
2+
import jax
3+
from jax import numpy as np
14
from matplotlib import pyplot as plt
25

3-
from optimism.JaxConfig import *
46
from optimism import EquationSolver as EqSolver
57
from optimism import FunctionSpace
68
from optimism import Interpolants
@@ -82,8 +84,8 @@ def energy_function(self, Uu, p):
8284
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables)
8385

8486

85-
@partial(jit, static_argnums=0)
86-
@partial(value_and_grad, argnums=2)
87+
@partial(jax.jit, static_argnums=0)
88+
@partial(jax.value_and_grad, argnums=2)
8789
def compute_reactions_from_bcs(self, Uu, Ubc, internalVariables):
8890
U = self.dofManager.create_field(Uu, Ubc)
8991
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables)

examples/uniaxial_dynamic/UniaxialDynamic.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#import sys
1+
from jax import numpy as np
22

33
from optimism.JaxConfig import *
44
from optimism import EquationSolver
@@ -94,13 +94,6 @@ def energy_function(self, Uu, p):
9494
return self.dynamicsFunctions.compute_algorithmic_energy(U, UPredictor, internalVariables, dt)
9595

9696

97-
# @partial(jit, static_argnums=0)
98-
# @partial(value_and_grad, argnums=2)
99-
# def compute_reactions_from_bcs(self, Uu, Ubc, internalVariables):
100-
# U = self.dofManager.create_field(Uu, Ubc)
101-
# return self.dynamicsFunctions.compute_output_potential_energies_and_stresses(U, internalVariables)
102-
103-
10497
def create_field(self, Uu, p):
10598
return self.dofManager.create_field(Uu, self.get_ubcs(p))
10699

optimism/JaxConfig.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,9 @@
11
from functools import partial
22
from collections import namedtuple
3-
import warnings
4-
with warnings.catch_warnings():
5-
warnings.filterwarnings("ignore")
6-
import jax.numpy as np
7-
from jax import grad, jacrev, jacfwd, jvp, vjp, value_and_grad, hessian, ops, lax, make_jaxpr, linearize
8-
from jax import custom_jvp, custom_vjp
9-
from jax import jit as jaxJit
10-
from jax import vmap as jaxVmap
11-
from jax.lax import while_loop
12-
from jax.config import config
13-
14-
15-
config.update("jax_enable_x64", True)
16-
#config.update("jax_debug_nans", True)
17-
18-
19-
jaxDebug=False
20-
21-
if jaxDebug:
22-
def jit(f,static_argnums=None):
23-
return f
24-
25-
vmap = jaxVmap
26-
27-
def while_loop(cond_fun, body_fun, init_val):
28-
val = init_val
29-
while cond_fun(val):
30-
val = body_fun(val)
31-
return val
32-
else:
33-
jit = jaxJit
34-
vmap = jaxVmap
3+
import jax.numpy as np
4+
from jax import grad, jacrev, jacfwd, jvp, vjp, vmap, value_and_grad, hessian, ops, lax, make_jaxpr, linearize, jit
5+
from jax import custom_jvp, custom_vjp
6+
from jax.lax import while_loop
357

368

379
def if_then_else(cond, val1, val2):

optimism/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from jax import config
2+
3+
# force double precision floating point arithmetic
4+
# this is deprecated in jax, we'll have to find another way soon.
5+
config.update("jax_enable_x64", True)
6+
7+
# silence warnings about no gpu/tpu
8+
config.update("jax_platforms", "cpu")
9+
10+
# debugging options
11+
#config.update("jax_debug_nans", True)
12+
#config.update("jax_debug_infs", True)
13+
#config.update("jax_disable_jit", True)
14+
15+
del config

0 commit comments

Comments
 (0)