J2 plasticity
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
import jax.numpy as jnp
from jax import Array
import numpy as np
import equinox as eqx
import matplotlib.pyplot as plt
from functools import partial
from xpektra import (
SpectralSpace,
make_field,
)
from xpektra.scheme import RotatedDifference
from xpektra.projection_operator import GalerkinProjection
from xpektra.spectral_operator import SpectralOperator
from xpektra.transform import FFTTransform
from xpektra.solvers.nonlinear import newton_krylov_solver, conjugate_gradient_while
import random
random.seed(1)
def place_circle(matrix, n, r, x_center, y_center):
for i in range(n):
for j in range(n):
if (i - x_center) ** 2 + (j - y_center) ** 2 <= r**2:
matrix[i][j] = 1
def generate_matrix_with_circles(n, x, r):
if r >= n:
raise ValueError("Radius r must be less than the size of the matrix n")
matrix = np.zeros((n, n), dtype=int)
placed_circles = 0
while placed_circles < x:
x_center = random.randint(0, n - 1)
y_center = random.randint(0, n - 1)
# Check if the circle fits within the matrix bounds
if (
x_center + r < n
and y_center + r < n
and x_center - r >= 0
and y_center - r >= 0
):
previous_matrix = matrix.copy()
place_circle(matrix, n, r, x_center, y_center)
if not np.array_equal(previous_matrix, matrix):
placed_circles += 1
return matrix
N = 199
ndim = 2
length = 1.0
x = 10
r = 20
structure = generate_matrix_with_circles(N, x, r)
cb = plt.imshow(structure, cmap='viridis')
plt.colorbar(cb)
plt.show()

# Helper to map properties to grid
def map_prop(structure, val_soft, val_hard):
return val_hard * structure + val_soft * (1 - structure)
# Properties
phase_contrast = 2.0
K_field = map_prop(structure, 0.833, phase_contrast * 0.833)
mu_field = map_prop(structure, 0.386, phase_contrast * 0.386)
H_field = map_prop(structure, 0.01, phase_contrast * 0.01) # Normalized
sigma_y_field = map_prop(structure, 0.003, phase_contrast * 0.003) # Normalized
n_exponent = 1.0
fft_transform = FFTTransform(dim=ndim)
space = SpectralSpace(lengths=(length, length), shape=(N, N), transform=fft_transform)
scheme = RotatedDifference(space=space)
op = SpectralOperator(scheme=scheme, space=space)
Ghat = GalerkinProjection(scheme=scheme)
dofs_shape = make_field(dim=ndim, shape=structure.shape, rank=2).shape
# Pre-compute Identity Tensors for the grid
# I2: (N,N,2,2), I4_dev: (N,N,2,2,2,2)
i = jnp.eye(ndim)
I = make_field(dim=ndim, shape=structure.shape, rank=ndim) * i # Broadcasted Identity
II = op.dyad(I, I) # Fourth-order Identity
# I2 = jnp.eye(ndim) + jnp.zeros(space.shape + (ndim, ndim)) # Broadcast
# I4_sym = op.dyad(I2, I2) # Placeholder, construct proper I4s if needed
# Or better, just implement the math directly in the material class using tensor_op
# --- 4. J2 PLASTICITY MATERIAL MODEL ---
class J2Plasticity(eqx.Module):
"""
Encapsulates the J2 Plasticity constitutive law and return mapping.
"""
K: jax.Array
mu: jax.Array
H: jax.Array
sigma_y: jax.Array
n: float
def yield_stress(self, ep: jax.Array) -> jax.Array:
return self.sigma_y + self.H * (ep**self.n)
@eqx.filter_jit
def compute_response(
self, eps_total: jax.Array, state_prev: tuple[jax.Array, ...]
) -> tuple:
"""
Computes stress and new state variables given total strain and history.
state_prev = (eps_total_t, eps_elastic_t, ep_t)
"""
eps_t, epse_t, ep_t = state_prev
# Trial State (assume elastic step)
# Delta eps = eps_total - eps_t
# Trial elastic strain = old elastic strain + Delta eps
epse_trial = epse_t + (eps_total - eps_t)
jax.debug.print("epse_trial: {}", epse_trial.shape)
# Volumetric / Deviatoric Split, 2D plane strain
trace_epse = op.trace(epse_trial)
epse_dev = epse_trial - (trace_epse[..., None, None] / 2.0) * jnp.eye(2)
jax.debug.print("trace_epse: {}", epse_dev.shape)
# Note: Be careful with 2D trace. If plane strain, tr=e11+e22.
# If plane stress, e33 is non-zero. Assuming plane strain for simplicity.
# Trial Stress
# sigma_vol = K * trace_epse * I
# sigma_dev = 2 * mu * epse_dev
sigma_vol = self.K[..., None, None] * trace_epse[..., None, None] * jnp.eye(2)
sigma_dev = 2.0 * self.mu[..., None, None] * epse_dev
sigma_trial = sigma_vol + sigma_dev
# Mises Stress
# sig_eq = sqrt(3/2 * s:s)
norm_s = jnp.sqrt(op.ddot(sigma_dev, sigma_dev))
sig_eq_trial = jnp.sqrt(1.5) * norm_s
# 2. Check Yield Condition
sig_y_current = self.yield_stress(ep_t)
phi = sig_eq_trial - sig_y_current
# 3. Return Mapping (if plastic)
# Mask for plastic points
is_plastic = phi > 0
# Plastic Multiplier Delta_gamma
# Denom = 3*mu + H
denom = 3.0 * self.mu + self.H # (Linear hardening H' = H)
d_gamma = jnp.where(is_plastic, phi / denom, 0.0)
# Update State
# Normal vector n = s_trial / |s_trial|
# s_new = s_trial - 2*mu*d_gamma * n
# This simplifies to scaling s_trial
scale_factor = jnp.where(
is_plastic, 1.0 - (3.0 * self.mu * d_gamma) / sig_eq_trial, 1.0
)
sigma_dev_new = sigma_dev * scale_factor[..., None, None]
sigma_new = sigma_vol + sigma_dev_new
# Update plastic strain
ep_new = ep_t + d_gamma
# Update elastic strain (back-calculate from stress)
# eps_e_new = eps_e_trial - d_gamma * n * sqrt(3/2) ...
# Easier: eps_e_new = C_inv : sigma_new
# Or just update deviatoric part
epse_dev_new = epse_dev * scale_factor[..., None, None]
epse_vol_new = trace_epse[..., None, None] * jnp.eye(2) # Volumetric is elastic
epse_new = epse_dev_new + epse_vol_new
return sigma_new, (eps_total, epse_new, ep_new)
# Instantiate Material
material = J2Plasticity(K_field, mu_field, H_field, sigma_y_field, n_exponent)
# --- 5. RESIDUAL & JACOBIAN ---
class PlasticityResidual(eqx.Module):
material: J2Plasticity
state_prev: tuple # (eps_t, epse_t, ep_t)
dofs_shape: tuple = eqx.field(static=True)
def __call__(self, eps_total_flat):
# Reshape
eps_total = eps_total_flat.reshape(self.dofs_shape)
# Compute Stress (Physics)
# We discard the new state here, we only need stress for residual
sigma, _ = self.material.compute_response(eps_total, self.state_prev)
# Compute Residual (Equilibrium)
sigma_hat = op.forward(sigma)
res_hat = Ghat.project(sigma_hat)
res = op.inverse(res_hat).real
return res.reshape(-1)
class PlasticityJacobian(eqx.Module):
residual_fn: PlasticityResidual # Holds all necessary data
def __call__(self, deps_flat):
# JAX JVP does the heavy lifting!
# It automatically linearizes the return mapping algorithm
# to give the Consistent Algorithmic Tangent Operator.
# We differentiate residual_fn w.r.t its input (eps_total)
# evaluated at the current guess (which is baked into residual_fn if we used Partial,
# but here we need the current 'eps' point).
# ISSUE: Jacobian needs the current 'eps' point to evaluate the tangent.
# The solver passes 'deps', assuming J(eps) * deps.
# We need to refactor slightly to pass 'eps' into the Jacobian constructor
# or use the 'linearize' approach in the solver.
pass
# --- 6. MAIN SIMULATION LOOP ---
# Initialize Fields
# Layout: (N, N, 2, 2)
eps_total = make_field(dim=ndim, shape=structure.shape, rank=2)
eps_elastic = make_field(dim=ndim, shape=structure.shape, rank=2)
ep_accum = make_field(dim=ndim, shape=structure.shape, rank=0) # Scalar plastic strain
state_current = (eps_total, eps_elastic, ep_accum)
# History storage
stress_history = []
# Load steps
n_steps = 20
max_strain = 0.02
strain_steps = jnp.linspace(0, max_strain, n_steps)
eps_macro_inc = make_field(dim=ndim, shape=structure.shape, rank=2)
print("Starting Plasticity Simulation...")
for step, eps_val in enumerate(strain_steps[1:]):
# 1. Define Macroscopic Strain Increment
# Pure shear loading
delta_eps = eps_val - strain_steps[step]
eps_macro_inc[..., 0, 1] = delta_eps
eps_macro_inc[..., 1, 0] = delta_eps
# Predictor (Initial Guess): eps_new = eps_old + deps_macro
eps_guess = state_current[0] + eps_macro_inc
# 2. Setup Solver Functions
# We bind the *previous* state history to the residual function
residual_fn = PlasticityResidual(material, state_current, eps_guess.shape)
# 3. Solve Equilibrium (Newton-Krylov)
# The solver needs a function f(x) -> b and a linear operator J(dx) -> db
# We use jax.linearize to get the Jacobian at the current guess 'x'
# Initial residual
b0 = -residual_fn(eps_guess.reshape(-1))
# Update guess
eps_iter = eps_guess
# Newton Loop (Manual or use xpektra solver)
for i in range(10): # Newton iterations
# Linearize residual at current eps_iter
# val is residual(eps_iter), lin_fn is the Jacobian operator J(dx)
val, lin_fn = jax.linearize(residual_fn, eps_iter.reshape(-1))
if jnp.linalg.norm(val) < 1e-8:
print(f"Step {step}: Converged in {i} iterations.")
break
# Solve J * dx = -val using CG
dx, _ = conjugate_gradient_while(lin_fn, -val, max_iter=50, atol=1e-5)
# Update
eps_iter = eps_iter + dx.reshape(eps_iter.shape)
# 4. Update State History
# Now we call the material one last time with the converged strain
# to get the official new internal variables
final_sigma, state_current = material.compute_response(eps_iter, state_current)
# Store results
avg_stress = jnp.mean(final_sigma, axis=(0, 1))
stress_history.append(avg_stress[0, 1])
# Plot
plt.plot(strain_steps[1:], stress_history, "-o")
plt.xlabel("Macroscopic Shear Strain")
plt.ylabel("Macroscopic Shear Stress")
plt.title("J2 Plasticity: Stress-Strain Curve")
plt.grid()
plt.show()
Starting Plasticity Simulation...
epse_trial: (Array(199, dtype=int64), Array(199, dtype=int64), Array(2, dtype=int64), Array(2, dtype=int64))
trace_epse: (Array(199, dtype=int64), Array(199, dtype=int64), Array(2, dtype=int64), Array(2, dtype=int64))
epse_trial: (Array(199, dtype=int64), Array(199, dtype=int64), Array(2, dtype=int64), Array(2, dtype=int64))
trace_epse: (Array(199, dtype=int64), Array(199, dtype=int64), Array(2, dtype=int64), Array(2, dtype=int64))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[19], line 207
204 break
206 # Solve J * dx = -val using CG
--> 207 dx, _ = conjugate_gradient_while(lin_fn, -val, max_iter=50, atol=1e-5)
209 # Update
210 eps_iter = eps_iter + dx.reshape(eps_iter.shape)
[... skipping hidden 1 frame]
File ~/Documents/dev/spectralsolvers/.venv/lib/python3.12/site-packages/equinox/_jit.py:263, in _call(jit_wrapper, is_lower, args, kwargs)
259 marker, _, _ = out = jit_wrapper._cached(
260 dynamic_donate, dynamic_nodonate, static
261 )
262 else:
--> 263 marker, _, _ = out = jit_wrapper._cached(
264 dynamic_donate, dynamic_nodonate, static
265 )
266 # We need to include the explicit `isinstance(marker, jax.Array)` check due
267 # to https://github.com/patrick-kidger/equinox/issues/988
268 if not isinstance(marker, jax.core.Tracer) and isinstance(
269 marker, jax.Array
270 ):
ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'tuple'>, (((<function conjugate_gradient_while at 0x7012d45d6c00>,), PyTreeDef(*)), ((None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, TypedNdArray([[[ 0.00000000e+00 +0.j ,
0.00000000e+00 +0.j ],
[ 0.00000000e+00 +0.j ,
-9.91837639e-02 +6.2821414j ],
[ 0.00000000e+00 +0.j ,
-3.96636187e-01+12.55802063j],
...,
[ 0.00000000e+00 +0.j ,
-8.92060762e-01-18.82138175j],
[ 0.00000000e+00 +0.j ,
-3.96636187e-01-12.55802063j],
[ 0.00000000e+00 +0.j ,
-9.91837639e-02 -6.2821414j ]],
[[-9.91837639e-02 +6.2821414j ,
0.00000000e+00 +0.j ],
[-1.98318094e-01 +6.27901032j,
-1.98318094e-01 +6.27901032j],
[-2.97304170e-01 +6.27275126j,
-5.94756593e-01+12.54863049j],
...,
[ 1.98120406e-01 +6.27275126j,
-5.94756593e-01-18.83077189j],
[ 9.91343296e-02 +6.27901032j,
-1.98318094e-01-12.56115172j],
[ 1.64497329e-18 +6.2821414j ,
-1.64497329e-18 -6.2821414j ]],
[[-3.96636187e-01+12.55802063j,
0.00000000e+00 +0.j ],
[-5.94756593e-01+12.54863049j,
-2.97304170e-01 +6.27275126j],
[-7.92481820e-01+12.53299065j,
-7.92481820e-01+12.53299065j],
...,
[ 1.98120406e-01+12.54863049j,
-2.97304170e-01-18.83077189j],
[ 2.03607527e-17+12.55802063j,
-2.03607527e-17-12.55802063j],
[-1.98318094e-01+12.56115172j,
9.91343296e-02 -6.27901032j]],
...,
[[-8.92060762e-01-18.82138175j,
0.00000000e+00 +0.j ],
[-5.94756593e-01-18.83077189j,
1.98120406e-01 +6.27275126j],
[-2.97304170e-01-18.83077189j,
1.98120406e-01+12.54863049j],
...,
[-1.78012267e+00-18.73701082j,
-1.78012267e+00-18.73701082j],
[-1.48503935e+00-18.77447784j,
-9.89614772e-01-12.51111672j],
[-1.18892032e+00-18.80261083j,
-3.96043321e-01 -6.26337048j]],
[[-3.96636187e-01-12.55802063j,
0.00000000e+00 +0.j ],
[-1.98318094e-01-12.56115172j,
9.91343296e-02 +6.27901032j],
[ 2.03607527e-17-12.55802063j,
-2.03607527e-17+12.55802063j],
...,
[-9.89614772e-01-12.51111672j,
-1.48503935e+00-18.77447784j],
[-7.92481820e-01-12.53299065j,
-7.92481820e-01-12.53299065j],
[-5.94756593e-01-12.54863049j,
-2.97304170e-01 -6.27275126j]],
[[-9.91837639e-02 -6.2821414j ,
0.00000000e+00 +0.j ],
[ 1.64497329e-18 -6.2821414j ,
-1.64497329e-18 +6.2821414j ],
[ 9.91343296e-02 -6.27901032j,
-1.98318094e-01+12.56115172j],
...,
[-3.96043321e-01 -6.26337048j,
-1.18892032e+00-18.80261083j],
[-2.97304170e-01 -6.27275126j,
-5.94756593e-01-12.54863049j],
[-1.98318094e-01 -6.27901032j,
-1.98318094e-01 -6.27901032j]]], dtype=complex128)), PyTreeDef(CustomNode(Partial[_HashableCallableShim(functools.partial(<function _lift_linearized at 0x7012d6518040>, let _where = { lambda ; a:bool[199,199] b:f64[199,199] c:f64[199,199]. let
d:f64[199,199] = select_n a b c
in (d,) } in
{ lambda e:bool[199,199,2,2] f:f64[199,199,2,2] g:f64[1,1,2,2] h:f64[199,199,1,1]
i:f64[1,1,2,2] j:f64[199,199,1,1] k:f64[199,199,2,2] l:f64[199,199] m:f64[] n:f64[199,199]
o:bool[199,199] p:f64[199,199] q:f64[199,199] r:f64[199,199] s:f64[199,199] t:f64[199,199]
u:f64[199,199] v:f64[199,199,1,1] w:c128[199,199,2] x:c128[199,199,2]; y:f64[158404]. let
z:f64[199,199,2,2] = reshape[
dimensions=None
new_sizes=(199, 199, 2, 2)
sharding=None
] y
ba:f64[199,199,2,2] = jit[
name=compute_response
jaxpr={ lambda ; e:bool[199,199,2,2] f:f64[199,199,2,2] g:f64[1,1,2,2] h:f64[199,199,1,1]
i:f64[1,1,2,2] j:f64[199,199,1,1] k:f64[199,199,2,2] l:f64[199,199] m:f64[]
n:f64[199,199] o:bool[199,199] p:f64[199,199] q:f64[199,199] r:f64[199,199]
s:f64[199,199] t:f64[199,199] u:f64[199,199] v:f64[199,199,1,1] z:f64[199,199,2,2]. let
bb:f64[199,199] = jit[
name=trace
jaxpr={ lambda ; e:bool[199,199,2,2] f:f64[199,199,2,2] z:f64[199,199,2,2]. let
bb:f64[199,199] = jit[
name=trace
jaxpr={ lambda ; e:bool[199,199,2,2] f:f64[199,199,2,2] z:f64[199,199,2,2]. let
bc:f64[199,199,2,2] = select_n e f z
bb:f64[199,199] = reduce_sum[
axes=(2, 3)
out_sharding=None
] bc
in (bb,) }
] e f z
in (bb,) }
] e f z
bd:f64[199,199,1,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(199, 199, 1, 1)
sharding=None
] bb
be:f64[199,199,1,1] = div bd 2.0:f64[]
bf:f64[199,199,2,2] = mul be g
bg:f64[199,199,2,2] = sub z bf
bh:f64[199,199,1,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(199, 199, 1, 1)
sharding=None
] bb
bi:f64[199,199,1,1] = mul h bh
bj:f64[199,199,2,2] = mul bi i
bk:f64[199,199,2,2] = mul j bg
bl:f64[199,199] = jit[
name=ddot
jaxpr={ lambda ; k:f64[199,199,2,2] bm:f64[199,199,2,2] bk:f64[199,199,2,2]
bn:f64[199,199,2,2]. let
bl:f64[199,199] = jit[
name=ddot
jaxpr={ lambda ; k:f64[199,199,2,2] bm:f64[199,199,2,2] bk:f64[199,199,2,2]
bn:f64[199,199,2,2]. let
bo:f64[199,199] = dot_general[
dimension_numbers=(([2, 3], [3, 2]), ([0, 1], [0, 1]))
preferred_element_type=float64
] bk k
bp:f64[199,199] = dot_general[
dimension_numbers=(([2, 3], [3, 2]), ([0, 1], [0, 1]))
preferred_element_type=float64
] bm bn
bl:f64[199,199] = add_any bo bp
in (bl,) }
] k bm bk bn
in (bl,) }
] k k bk bk
bq:f64[199,199] = mul bl l
br:f64[199,199] = mul m bq
bs:f64[199,199] = div br n
bt:f64[199,199] = jit[name=_where jaxpr=_where] o p bs
bu:f64[199,199] = mul q bt
bv:f64[199,199] = div bu r
bw:f64[199,199] = neg br
bx:f64[199,199] = mul bw s
by:f64[199,199] = mul bx t
bz:f64[199,199] = add_any bv by
ca:f64[199,199] = neg bz
cb:f64[199,199] = jit[name=_where jaxpr=_where] o u ca
cc:f64[199,199,1,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(199, 199, 1, 1)
sharding=None
] cb
cd:f64[199,199,2,2] = mul bk v
ce:f64[199,199,2,2] = mul k cc
cf:f64[199,199,2,2] = add_any cd ce
ba:f64[199,199,2,2] = add bj cf
in (ba,) }
] e f g h i j k l m n o p q r s t u v z
cg:c128[199,199,2,2] = jit[
name=forward
jaxpr={ lambda ; ba:f64[199,199,2,2]. let
cg:c128[199,199,2,2] = jit[
name=forward
jaxpr={ lambda ; ba:f64[199,199,2,2]. let
ch:f64[2,2,199,199] = transpose[permutation=(2, 3, 0, 1)] ba
ci:c128[2,2,199,199] = jit[
name=fft
jaxpr={ lambda ; ch:f64[2,2,199,199]. let
cj:c128[2,2,199,199] = convert_element_type[
new_dtype=complex128
weak_type=False
] ch
ci:c128[2,2,199,199] = fft[
fft_lengths=(199, 199)
fft_type=0
] cj
in (ci,) }
] ch
cg:c128[199,199,2,2] = transpose[permutation=(2, 3, 0, 1)] ci
in (cg,) }
] ba
in (cg,) }
] ba
ck:c128[199,199,2,2] = jit[
name=project
jaxpr={ lambda ; w:c128[199,199,2] x:c128[199,199,2] cg:c128[199,199,2,2]. let
cl:c128[199,199,2] = dot_general[
dimension_numbers=(([2], [3]), ([0, 1], [0, 1]))
preferred_element_type=complex128
] w cg
ck:c128[199,199,2,2] = dot_general[
dimension_numbers=(([], []), ([0, 1], [0, 1]))
preferred_element_type=complex128
] cl x
in (ck,) }
] w x cg
cm:f64[199,199,2,2] = jit[
name=inverse
jaxpr={ lambda ; ck:c128[199,199,2,2]. let
cn:c128[199,199,2,2] = jit[
name=inverse
jaxpr={ lambda ; ck:c128[199,199,2,2]. let
co:c128[2,2,199,199] = transpose[permutation=(2, 3, 0, 1)] ck
cp:c128[2,2,199,199] = jit[
name=fft
jaxpr={ lambda ; co:c128[2,2,199,199]. let
cp:c128[2,2,199,199] = fft[
fft_lengths=(199, 199)
fft_type=1
] co
in (cp,) }
] co
cn:c128[199,199,2,2] = transpose[permutation=(2, 3, 0, 1)] cp
in (cn,) }
] ck
cm:f64[199,199,2,2] = real cn
in (cm,) }
] ck
cq:f64[158404] = reshape[dimensions=None new_sizes=(158404,) sharding=None] cm
in (cq,) }, [ShapedArray(float64[158404])], (PyTreeDef((*,)), PyTreeDef(*)), [(ShapedArray(float64[158404]), None)]))], [([*, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *, *],), {}]))), ((None, 1e-05, 50), PyTreeDef(((*, *, *, None), {})))). The error was:
TypeError: unhashable type: 'TypedNdArray'
ep_accum.shape
(199, 199, 2)
import jax
jax.config.update("jax_enable_x64", True) # use double-precision
jax.config.update("jax_platforms", "cpu")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
import jax.numpy as jnp
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
from skimage.morphology import disk, ellipse, rectangle
from spectralsolver import (
DifferentialMode,
SpectralSpace,
TensorOperator,
make_field,
)
from spectralsolver.operators import fourier_galerkin
from typing import Callable
In this notebook, we implement the small-strain J2 plasticity. We use the automatic differentiation to compute the alogriothmic tangent stiffness matrix (which ia function of bbothe elastic strain and plastic strain and yield function). It is necessary for Newton-Raphson iteration.
constructing an RVE¤
import random
random.seed(1)
def place_circle(matrix, n, r, x_center, y_center):
for i in range(n):
for j in range(n):
if (i - x_center) ** 2 + (j - y_center) ** 2 <= r**2:
matrix[i][j] = 1
def generate_matrix_with_circles(n, x, r):
if r >= n:
raise ValueError("Radius r must be less than the size of the matrix n")
matrix = np.zeros((n, n), dtype=int)
placed_circles = 0
while placed_circles < x:
x_center = random.randint(0, n - 1)
y_center = random.randint(0, n - 1)
# Check if the circle fits within the matrix bounds
if (
x_center + r < n
and y_center + r < n
and x_center - r >= 0
and y_center - r >= 0
):
previous_matrix = matrix.copy()
place_circle(matrix, n, r, x_center, y_center)
if not np.array_equal(previous_matrix, matrix):
placed_circles += 1
return matrix
# Example usage
N = 199
shape = (N, N)
length = 1.0
ndim = 2
x = 10
r = 20
structure = generate_matrix_with_circles(N, x, r)
grid_size = (N,) * ndim
elasticity_dof_shape = (ndim, ndim) + grid_size
assigning material parameters¤
We assign material parameters to the two phases. The two phases within the RVE are denoted as - Soft = 0 - Hard = 1
# material parameters + function to convert to grid of scalars
def param(X, soft, hard):
return hard * jnp.ones_like(X) * (X) + soft * jnp.ones_like(X) * (1 - X)
We consider a linear isotropic hardening law for both the phases
# material parameters
phase_constrast = 2
K = param(structure, soft=0.833, hard=phase_constrast * 0.833) # bulk modulus
μ = param(structure, soft=0.386, hard=phase_constrast * 0.386) # shear modulus
H = param(
structure, soft=2000.0e6 / 200.0e9, hard=phase_constrast * 2000.0e6 / 200.0e9
) # hardening modulus
sigma_y = param(
structure, soft=600.0e6 / 200.0e9, hard=phase_constrast * 600.0e6 / 200.0e9
) # initial yield stress
n = 1.0
plasticity basics¤
Now we define the basics of plasticity implementation:
- yield surface
\[
\Phi(\sigma_{ij}, \varepsilon^p_{ij}) = \underbrace{\sqrt{\dfrac{3}{2}\sigma^{dev}_{ij}\sigma^{dev}_{jk}}}_{\sigma^{eq}} - (\sigma_{0} + H\varepsilon^{p})
\]
- return mappping algorithm
\[
\Delta \varepsilon = \dfrac{\langle \Phi(\sigma_{ij}, \varepsilon_{p}) \rangle_{+}}{3\mu + H}
\]
- tangent stiffness operator
\[
\mathbb{C} = \dfrac{\partial \sigma^{t+1}}{\partial \varepsilon^{t+1}}
\]
We also define certain Identity tensor for each grid point.
- \(\mathbf{I}\) = 2 order Identity tensor with shape
(2, 2, N, N) - \(\mathbb{I4}\) = 4 order Identity tensor with shape
(2, 2, 2, 2, N, N)
tensor = TensorOperator(dim=ndim)
space = SpectralSpace(size=N, dim=ndim, length=length)
# identity tensor (single tensor)
i = jnp.eye(ndim)
# identity tensors (grid)
I = jnp.einsum(
"ij,xy",
i,
jnp.ones(
[
N,
]
* ndim
),
) # 2nd order Identity tensor
I4 = jnp.einsum(
"ijkl,xy->ijklxy",
jnp.einsum("il,jk", i, i),
jnp.ones(
[
N,
]
* ndim
),
) # 4th order Identity tensor
I4rt = jnp.einsum(
"ijkl,xy->ijklxy",
jnp.einsum("ik,jl", i, i),
jnp.ones(
[
N,
]
* ndim
),
)
I4s = (I4 + I4rt) / 2.0
II = tensor.dyad(I, I)
I4d = I4s - II / 3.0
Ghat = fourier_galerkin.compute_projection_operator(
space=space, diff_mode=DifferentialMode.rotated_difference
)
import equinox as eqx
@jax.jit
def yield_function(ep: jnp.ndarray):
return sigma_y + H * ep**n
@jax.jit
def compute_stress(eps: jnp.ndarray, args: tuple):
eps_t, epse_t, ep_t = args
# elastic stiffness tensor
C4e = K * II + 2.0 * μ * I4d
# trial state
epse_s = epse_t + (eps - eps_t)
sig_s = tensor.ddot(C4e, epse_s)
sigm_s = tensor.ddot(sig_s, I) / 3.0
sigd_s = sig_s - sigm_s * I
sigeq_s = jnp.sqrt(3.0 / 2.0 * tensor.ddot(sigd_s, sigd_s))
# avoid zero division below ("phi_s" is corrected below)
Z = jnp.where(sigeq_s == 0, True, False)
sigeq_s = jnp.where(Z, 1, sigeq_s)
# evaluate yield surface, set to zero if elastic (or stress-free)
sigy = yield_function(ep_t)
phi_s = sigeq_s - sigy
phi_s = 1.0 / 2.0 * (phi_s + jnp.abs(phi_s))
phi_s = jnp.where(Z, 0.0, phi_s)
elastic_pt = jnp.where(phi_s <= 0, True, False)
# plastic multiplier, based on non-linear hardening
# - initialize
dep = phi_s / (3 * μ + H)
# return map algorithm
N = 3.0 / 2.0 * sigd_s / sigeq_s
ep = ep_t + dep
sig = sig_s - dep * N * 2.0 * μ
epse = epse_s - dep * N
return sig, epse, ep
@eqx.filter_jit
def compute_residual(sigma: jnp.ndarray) -> jnp.ndarray:
return jnp.real(space.ifft(tensor.ddot(Ghat, space.fft(sigma)))).reshape(-1)
@eqx.filter_jit
def compute_tangents(deps: jnp.ndarray, args: tuple):
deps = deps.reshape(ndim, ndim, N, N)
eps, eps_t, epse_t, ep_t = args
primal, tangents = jax.jvp(
partial(compute_stress, args=(eps_t, epse_t, ep_t)), (eps,), (deps,)
)
return compute_residual(tangents[0])
# partial_compute_tangent = partial(compute_tangents, sigma=sigma)
from spectralsolver.solvers.nonlinear import (
conjugate_gradient_while,
newton_krylov_solver,
)
@jax.jit
def newton_solver(state, n):
deps, b, eps, eps_t, epse_t, ep_t, En, sig = state
error = jnp.linalg.norm(deps) / En
jax.debug.print("residual={}", jnp.linalg.norm(deps) / En)
def true_fun(state):
deps, b, eps, eps_t, epse_t, ep_t, En, sig = state
partial_compute_tangent = jax.jit(
partial(compute_tangents, args=(eps, eps_t, epse_t, ep_t))
)
deps, iiter = conjugate_gradient_while(
atol=1e-8,
A=partial_compute_tangent,
b=b,
) # solve linear system using CG
deps = deps.reshape(eps.shape)
eps = jax.lax.add(eps, deps) # update DOFs (array -> tensor.grid)
sig, epse, ep = compute_stress(eps, (eps_t, epse_t, ep_t))
b = -compute_residual(sig) # compute residual
jax.debug.print("CG iteration {}", iiter)
return (deps, b, eps, eps_t, epse, ep, En, sig)
def false_fun(state):
return state
return jax.lax.cond(error > 1e-8, true_fun, false_fun, state), n
# initialize: stress and strain tensor, and history
sig = make_field(dim=ndim, N=N, rank=2)
eps = make_field(dim=ndim, N=N, rank=2)
eps_t = make_field(dim=ndim, N=N, rank=2)
epse_t = make_field(dim=ndim, N=N, rank=2)
ep_t = make_field(dim=ndim, N=N, rank=2)
deps = make_field(dim=ndim, N=N, rank=2)
# define incremental macroscopic strain
ninc = 100
epsbar = 0.12
deps[0, 0] = jnp.sqrt(3.0) / 2.0 * epsbar / float(ninc)
deps[1, 1] = -jnp.sqrt(3.0) / 2.0 * epsbar / float(ninc)
b = -compute_tangents(deps, (eps, eps_t, epse_t, ep_t))
eps = jax.lax.add(eps, deps)
En = jnp.linalg.norm(eps)
state = (deps, b, eps, eps_t, epse_t, ep_t, En, sig)
final_state, xs = jax.lax.scan(newton_solver, init=state, xs=jnp.arange(0, 10))
residual=1.0
CG iteration 18
residual=0.30347510072366196
CG iteration 0
residual=0.0
residual=0.0
residual=0.0
residual=0.0
residual=0.0
residual=0.0
residual=0.0
residual=0.0