Linear elasticity with SciPy¤
Colab Setup (Install Dependencies)
# Only run this if we are in Google Colab
if 'google.colab' in str(get_ipython()):
print("Installing dependencies using uv...")
# Install uv if not available
!pip install -q uv
# Install system dependencies
!apt-get install -qq gmsh
# Use uv to install Python dependencies
!uv pip install --system matplotlib meshio
!uv pip install --system "git+https://github.com/smec-ethz/tatva-docs.git"
print("Installation complete!")
In this notebook, we solve a linear elasticity problem using SciPy-based sparse solvers.
from typing import NamedTuple
import gmsh
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import Array
from jax_autovmap import autovmap
from tatva import Mesh, Operator, sparse
import scipy
Code for generating a plate with a hole and plotting the mesh
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import meshio
import numpy as np
import os
def plot_mesh(mesh: Mesh, ax: Axes | None = None) -> None:
    """Draw the triangles of ``mesh`` with a uniform face colour and thin black edges.

    Args:
        mesh: Mesh whose ``coords`` (columns 0 and 1 are used as x and y) and
            ``elements`` (triangle connectivity) are plotted.
        ax: Axes to draw into. A new figure and axes are created when ``None``.
    """
    if ax is None:
        _, ax = plt.subplots()

    x_nodes = mesh.coords[:, 0]
    y_nodes = mesh.coords[:, 1]
    # One identical value per triangle, so every face gets the same colour.
    uniform_faces = np.ones(len(mesh.elements))

    ax.tripcolor(
        x_nodes,
        y_nodes,
        mesh.elements,
        facecolors=uniform_faces,
        cmap="managua",
        edgecolors="k",
        linewidth=0.2,
    )
    ax.set_aspect("equal")
    ax.margins(0.0)
def generate_refined_plate_with_hole(
    width: float,
    height: float,
    hole_radius: float,
    mesh_size_fine: float,
    mesh_size_coarse: float,
) -> Mesh:
    """Mesh a rectangular plate with a centred circular hole, refined near the hole.

    The geometry is built with gmsh's OpenCASCADE kernel, meshed in 2D, written
    to ``../meshes/plate_hole_refined.msh`` (relative to the current working
    directory) and read back with meshio.

    Args:
        width: Plate extent in x; the rectangle starts at the origin.
        height: Plate extent in y; the rectangle starts at the origin.
        hole_radius: Radius of the hole centred at ``(width / 2, height / 2)``.
        mesh_size_fine: Element size requested at the hole boundary.
        mesh_size_coarse: Element size requested at a distance of two hole
            radii (or more) from the hole boundary.

    Returns:
        A ``Mesh`` built from the x/y node coordinates and the triangle
        connectivity of the generated mesh.
    """
    mesh_dir = os.path.join(os.getcwd(), "../meshes")
    os.makedirs(mesh_dir, exist_ok=True)
    output_filename = os.path.join(mesh_dir, "plate_hole_refined.msh")

    gmsh.initialize()
    try:
        gmsh.model.add("plate_with_hole_refined")
        occ = gmsh.model.occ

        # Geometry: rectangle minus a disk centred in the plate.
        rect = occ.addRectangle(0, 0, 0, width, height)
        cx = width / 2.0
        cy = height / 2.0
        disk = occ.addDisk(cx, cy, 0, hole_radius, hole_radius)
        out, _ = occ.cut([(2, rect)], [(2, disk)])
        occ.synchronize()

        surface_tag = out[0][1]
        gmsh.model.addPhysicalGroup(2, [surface_tag], 1, name="domain")
        boundaries = gmsh.model.getBoundary(out, oriented=False)
        boundary_tags = [b[1] for b in boundaries]
        gmsh.model.addPhysicalGroup(1, boundary_tags, 2, name="boundaries")

        # Identify the hole curves: their bounding box lies strictly inside
        # the outer rectangle, unlike the four plate edges.
        hole_curve_tags = []
        for tag in boundary_tags:
            xmin, ymin, _, xmax, ymax, _ = gmsh.model.getBoundingBox(1, tag)
            if xmin > 0 and xmax < width and ymin > 0 and ymax < height:
                hole_curve_tags.append(tag)

        # Field 1: distance to the hole boundary.
        gmsh.model.mesh.field.add("Distance", 1)
        gmsh.model.mesh.field.setNumbers(1, "CurvesList", hole_curve_tags)
        gmsh.model.mesh.field.setNumber(1, "NumPointsPerCurve", 100)

        # Field 2: element size driven by field 1.
        gmsh.model.mesh.field.add("Threshold", 2)
        gmsh.model.mesh.field.setNumber(2, "InField", 1)
        gmsh.model.mesh.field.setNumber(2, "SizeMin", mesh_size_fine)
        gmsh.model.mesh.field.setNumber(2, "SizeMax", mesh_size_coarse)
        # Start growing the mesh exactly at the hole boundary
        gmsh.model.mesh.field.setNumber(2, "DistMin", 0.0)
        # Reach maximum element size at a distance equal to 2 hole radii away
        gmsh.model.mesh.field.setNumber(2, "DistMax", hole_radius * 2.0)
        gmsh.model.mesh.field.setAsBackgroundMesh(2)

        # Switch off the other size sources so only the background field applies.
        gmsh.option.setNumber("Mesh.MeshSizeExtendFromBoundary", 0)
        gmsh.option.setNumber("Mesh.MeshSizeFromPoints", 0)
        gmsh.option.setNumber("Mesh.MeshSizeFromCurvature", 0)

        gmsh.model.mesh.generate(2)
        gmsh.write(output_filename)
    finally:
        # Always close the gmsh session, even when geometry or meshing fails;
        # otherwise the session would stay initialised for the next call.
        gmsh.finalize()

    _mesh = meshio.read(output_filename)
    coords = _mesh.points[:, :2]  # drop the z column
    elements = _mesh.cells_dict["triangle"]
    return Mesh(coords=coords, elements=elements)
# Plate dimensions
lx, ly = 1.0, 1.0

mesh = generate_refined_plate_with_hole(
    width=lx,
    height=ly,
    hole_radius=0.2,
    mesh_size_fine=0.01,
    mesh_size_coarse=0.05,
)

# Two displacement components (x, y) per node
n_dofs_per_node = 2
n_nodes = mesh.coords.shape[0]
n_dofs = n_nodes * n_dofs_per_node

plot_mesh(mesh)
Info : Meshing 1D...
Info : [ 0%] Meshing curve 5 (Ellipse)
Info : [ 30%] Meshing curve 6 (Line)
Info : [ 50%] Meshing curve 7 (Line)
Info : [ 70%] Meshing curve 8 (Line)
Info : [ 90%] Meshing curve 9 (Line)
Info : Done meshing 1D (Wall 0.0223751s, CPU 0.022052s)
Info : Meshing 2D...
Info : Meshing surface 1 (Plane, Frontal-Delaunay)
Info : Done meshing 2D (Wall 0.0601408s, CPU 0.052454s)
Info : 1874 nodes 3753 elements
Info : Writing '/home/mohit/Documents/dev/tatva-docs/docs/external_solvers/../meshes/plate_hole_refined.msh'...
Info : Done writing '/home/mohit/Documents/dev/tatva-docs/docs/external_solvers/../meshes/plate_hole_refined.msh'

Problem setup¤
from tatva.element import Tri3
op = Operator(mesh, Tri3())

# Node indices on the left (x = 0) and right (x = lx) edges of the plate.
boundary_left = jnp.where(jnp.isclose(mesh.coords[:, 0], 0.0))[0]
boundary_right = jnp.where(jnp.isclose(mesh.coords[:, 0], lx))[0]

# Node at the bottom-right corner (x = lx, y = 0).
corner_candidates = jnp.where(
    jnp.isclose(mesh.coords[:, 0], lx) & jnp.isclose(mesh.coords[:, 1], 0.0)
)[0]
# Check for existence explicitly: asserting on the index value itself would
# wrongly fail whenever the corner happens to be node 0.
if corner_candidates.size == 0:
    raise ValueError("No mesh node found at the corner (x=lx, y=0).")
point_at_y_0 = corner_candidates[0]

# DOFs are interleaved per node: x-DOF = node * n_dofs_per_node, y-DOF = x-DOF + 1.
# Fixing only the x-DOFs of the left edge leaves a rigid translation in y
# unconstrained, which makes the reduced stiffness matrix singular. Pinning
# the y-DOF of one corner node removes that mode.
fixed_dofs = jnp.concatenate(
    [
        boundary_left * n_dofs_per_node,
        jnp.atleast_1d(point_at_y_0 * n_dofs_per_node + 1),
    ]
)
free_dofs = jnp.setdiff1d(jnp.arange(n_dofs), fixed_dofs)
Defining energy functional¤
We now define the functions to compute the total strain energy
class Material(NamedTuple):
    """Lamé parameters consumed by the elasticity energy.

    Fields:
        mu: Shear modulus.
        lmbda: First Lamé parameter.
    """

    mu: float
    lmbda: float

    @classmethod
    def from_youngs_poisson_2d(
        cls, E: float, nu: float, plane_stress: bool = False
    ) -> "Material":
        """Build a ``Material`` from Young's modulus and Poisson's ratio.

        Args:
            E: Young's modulus.
            nu: Poisson's ratio.
            plane_stress: When ``True`` use the plane-stress expression for
                ``lmbda``; otherwise (default) the plane-strain one.

        Returns:
            The corresponding ``Material``.
        """
        shear = E / 2 / (1 + nu)
        if not plane_stress:
            first_lame = E * nu / (1 - 2 * nu) / (1 + nu)
        else:
            first_lame = 2 * nu * shear / (1 - nu)
        return cls(mu=shear, lmbda=first_lame)


# Material used throughout the notebook: E = 1, nu = 0.3, plane strain.
mat = Material.from_youngs_poisson_2d(1, 0.3)
# NOTE(review): autovmap(grad_u=2) presumably declares grad_u as a rank-2
# core argument and maps over any leading axes - confirm against jax_autovmap.
@autovmap(grad_u=2)
def compute_strain(grad_u):
    """Return the symmetric part of the displacement gradient ``grad_u``."""
    symmetric_sum = grad_u + grad_u.T
    return 0.5 * symmetric_sum
@autovmap(eps=2, mu=0, lmbda=0)
def compute_stress(eps, mu, lmbda):
    """Return the linear-elastic stress ``2 mu eps + lmbda tr(eps) I`` for a 2x2 strain."""
    shear_part = 2 * mu * eps
    volumetric_part = lmbda * jnp.trace(eps) * jnp.eye(2)
    return shear_part + volumetric_part
@autovmap(grad_u=2, mu=0, lmbda=0)
def strain_energy_density(grad_u, mu, lmbda):
    """Return the strain energy density ``0.5 * sigma : eps`` for ``grad_u``."""
    strain = compute_strain(grad_u)
    stress = compute_stress(strain, mu, lmbda)
    # Full double contraction sigma_ij * eps_ij
    contraction = jnp.einsum("ij,ij->", stress, strain)
    return 0.5 * contraction
Enforcing boundary condition via static condensation¤
@jax.jit
def total_energy_full(u_flat: Array) -> Array:
    """Integrate the strain energy density for a full, flat displacement vector.

    Args:
        u_flat: Displacements of all DOFs, reshaped here to two columns
            (one row per node).

    Returns:
        The result of ``op.integrate`` applied to the strain energy density.
    """
    nodal_u = u_flat.reshape(-1, 2)
    energy_density = strain_energy_density(op.grad(nodal_u), mat.mu, mat.lmbda)
    return op.integrate(energy_density)
@jax.jit
def total_energy(u_free: Array, applied_disp: Array) -> Array:
    """Compute the total energy as a function of the free DOFs only.

    Args:
        u_free: Displacement values written to the ``free_dofs`` entries.
        applied_disp: Prescribed values written to the ``fixed_dofs`` entries.

    Returns:
        ``total_energy_full`` evaluated on the assembled full vector.
    """
    # Scatter free and prescribed values into one full-length vector.
    u_full = (
        jnp.zeros(n_dofs)
        .at[free_dofs]
        .set(u_free)
        .at[fixed_dofs]
        .set(applied_disp)
    )
    return total_energy_full(u_full)


# Derivative of the energy with respect to the first argument (u_free).
residual = jax.jacrev(total_energy)
Now we can define the sparsity_pattern and use sparse.jacfwd to create a function that computes the sparse stiffness matrix using sparse differentiation.
# Alias of the global per-node DOF count rather than a second hard-coded 2,
# so the sparsity pattern can never disagree with the DOF numbering used above.
nb_dofs_per_node = n_dofs_per_node

# Pattern for the full system.
sparsity_pattern = sparse.create_sparsity_pattern(
    mesh, n_dofs_per_node=nb_dofs_per_node
)
# NOTE(review): presumably restricts the pattern to the rows/columns of the
# free DOFs - confirm against the tatva.sparse documentation.
reduced_sparsity_pattern = sparse.reduce_sparsity_pattern(sparsity_pattern, free_dofs)
colored_matrix = sparse.ColoredMatrix.from_csr(reduced_sparsity_pattern)

# Sparse forward-mode derivative of the residual, used below as the stiffness matrix.
hessian_fn = sparse.jacfwd(
    fn=residual,
    colored_matrix=colored_matrix,
    color_batch_size=10,
)
Defining the loading traction on right edge¤
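The load on the right edge is applied as nodal forces: the x-DOF of every node on the right edge receives a force of magnitude sig_loading. Note that this is a lumped point-load approximation rather than an integrated traction; for a mesh-independent traction one would define a separate Operator consisting of line elements along the right edge and integrate the traction over it.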
sig_loading = 1e-2

# x-DOF of every node on the right edge (DOFs are interleaved per node).
idx_right = n_dofs_per_node * boundary_right

# NOTE(review): each right-edge node receives the same force sig_loading,
# with no weighting by the edge length attached to the node.
f_ext_0 = jnp.zeros(n_dofs).at[idx_right].add(sig_loading)

# Keep only the entries belonging to the free DOFs.
f_ext = f_ext_0.at[free_dofs].get()
Now let us solve the linear problem using scipy sparse solver scipy.sparse.linalg.spsolve
\[
\mathbf{K}\boldsymbol{u} =\boldsymbol{f}_\text{ext}
\]
We will use the above defined hessian_fn to compute the sparse stiffness matrix and will convert it to scipy.sparse.csr_matrix.
# The problem is linear, so the stiffness matrix is evaluated once at zero
# displacement with zero prescribed values.
u_free_zero = jnp.zeros(len(free_dofs))
applied_zero = jnp.zeros(len(fixed_dofs))
K_sparse = hessian_fn(u_free_zero, applied_disp=applied_zero)

# Solve K du = f_ext on the free DOFs.
K_csr = K_sparse.to_csr()
du = scipy.sparse.linalg.spsolve(K_csr, f_ext)
Visualization and analyzing the results¤
Code to visualize the results
from matplotlib.tri import Triangulation
# Scatter the solution into a full vector (fixed DOFs stay zero), one row per node.
u = jnp.zeros(n_dofs).at[free_dofs].set(du).reshape(-1, 2)

fig, ax = plt.subplots(figsize=(7.4, 3))

# Triangulation of the deformed configuration.
x_final = mesh.coords + u
tri = Triangulation(x_final[:, 0], x_final[:, 1], mesh.elements)

# Stress from the solved displacement field, with singleton axes removed.
strain = compute_strain(op.grad(u))
sig = compute_stress(strain, mat.mu, mat.lmbda).squeeze()
def plot_field(ax: Axes):
    """Colour the deformed mesh by the ``[0, 1]`` stress component and return the artist.

    Reads the module-level ``tri`` and ``sig``.

    Args:
        ax: Axes to draw into.

    Returns:
        The object returned by ``ax.tripcolor``, suitable for a colorbar.
    """
    axis_labels = {"xlabel": "$x$", "ylabel": "$y$"}
    shear_component = sig[..., 0, 1]
    cb = ax.tripcolor(
        tri,
        shear_component,
        alpha=0.95,
        rasterized=True,
        cmap="managua",
    )
    ax.set_aspect("equal")
    ax.set(**axis_labels)
    return cb


cb = plot_field(ax)
# Hide the axes frame, ticks and labels; only the field and colorbar remain.
ax.set_axis_off()
plt.colorbar(cb, ax=ax, label=r"$\sigma_{xy}$", shrink=0.7)
fig.tight_layout()
