Skip to content

Open In Colab

Rigid Body Constraints¤

Colab Setup (Install Dependencies)
# Only run this if we are in Google Colab
if 'google.colab' in str(get_ipython()):
    print("Installing dependencies from pyproject.toml...")
    # This installs the repo itself (and its dependencies)
    !apt-get install gmsh 
    !apt-get install -qq xvfb libgl1-mesa-glx
    !pip install pyvista -qq
    !pip install -q "git+https://github.com/smec-ethz/tatva-docs.git"    
    print("Installation complete!")

In the example, we implement rigid body constraint to a nonlinear material using static condensation

from typing import NamedTuple

import gmsh
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax_autovmap import autovmap

from tatva import Mesh, Operator

jax.config.update("jax_enable_x64", True)
Code for generating a plate with a hole geometry and meshing it.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def extract_physical_groups(tag_map: dict) -> dict[str, np.ndarray]:
    print("Extracting physical groups from Gmsh model...")
    physical_surfaces: dict[str, np.ndarray] = {}

    for dim, pg_tag in gmsh.model.getPhysicalGroups(dim=1):
        name = gmsh.model.getPhysicalName(dim, pg_tag)

        # Entities (surface tags) that belong to this physical group
        entities = gmsh.model.getEntitiesForPhysicalGroup(dim, pg_tag)

        els = []
        for ent in entities:
            # Get all mesh elements on this surface entity
            types, _, node_tags_by_type = gmsh.model.mesh.getElements(dim, ent)

            for etype, ntags in zip(types, node_tags_by_type):
                nodes = np.array(ntags, dtype=np.int64).reshape(-1, etype + 1)
                els.append(nodes)

        if not els:
            physical_surfaces[name] = np.zeros((0, 2), dtype=np.int32)
            continue

        group_els = np.vstack(els, dtype=np.int32)
        group_els = np.array(
            [[tag_map[t] for t in tri] for tri in group_els], dtype=np.int32
        )
        physical_surfaces[name] = group_els

    return physical_surfaces



def plot_mesh(mesh: Mesh, ax: Axes | None = None) -> None:
    if ax is None:
        fig, ax = plt.subplots()
    ax.tripcolor(
        mesh.coords[:, 0],
        mesh.coords[:, 1],
        mesh.elements,
        facecolors=np.ones(len(mesh.elements)),
        cmap="managua",
        edgecolors="k",
        linewidth=0.2,
    )
    ax.margins(0, 0)
    ax.set_aspect("equal")

def generate_plate(length: float, height: float, hole_radius: float) -> Mesh:
    y0, y1 = 0 - height / 2, height / 2
    x0, x1 = 0.0, length
    gmsh.initialize()
    gmsh.model.add("plate_with_hole")
    occ = gmsh.model.occ

    # Define the geometry
    p1 = occ.addPoint(x0, y0, 0)
    p2 = occ.addPoint(x1, y0, 0)
    p3 = occ.addPoint(x1, y1, 0)
    p4 = occ.addPoint(x0, y1, 0)

    c1 = occ.addPoint(0, -hole_radius, 0)
    c2 = occ.addPoint(0, hole_radius, 0)
    cmid = occ.addPoint(0, 0, 0)
    c_hole = occ.add_circle_arc(c2, cmid, c1)

    l1 = occ.addLine(p1, p2)
    l2 = occ.addLine(p2, p3)
    l3 = occ.addLine(p3, p4)
    l4 = occ.addLine(p4, c2)
    l5 = c_hole
    l6 = occ.addLine(c1, p1)

    cl = occ.addCurveLoop([l1, l2, l3, l4, l5, l6])
    surface = gmsh.model.occ.addPlaneSurface([cl])

    gmsh.model.occ.synchronize()

    # Refine along the hole boundary
    base_size = min(length, height) / 30
    fine_size = base_size / 3
    dist_field = gmsh.model.mesh.field.add("Distance")
    gmsh.model.mesh.field.setNumbers(dist_field, "CurvesList", [c_hole])
    thresh_field = gmsh.model.mesh.field.add("Threshold")
    gmsh.model.mesh.field.setNumber(thresh_field, "InField", dist_field)
    gmsh.model.mesh.field.setNumber(thresh_field, "SizeMin", fine_size)
    gmsh.model.mesh.field.setNumber(thresh_field, "SizeMax", base_size)
    gmsh.model.mesh.field.setNumber(thresh_field, "DistMin", hole_radius * 0.25)
    gmsh.model.mesh.field.setNumber(thresh_field, "DistMax", hole_radius * 1.25)
    gmsh.model.mesh.field.setAsBackgroundMesh(thresh_field)

    # Mesh the geometry
    gmsh.model.mesh.generate(2)

    gmsh.model.addPhysicalGroup(2, [surface], 2, "plate_surface")
    gmsh.model.addPhysicalGroup(1, [c_hole], 1, "hole_boundary")
    gmsh.model.addPhysicalGroup(1, [l4, l6], 3, "left_edge")
    gmsh.model.addPhysicalGroup(1, [l2], 4, "right_edge")
    gmsh.model.addPhysicalGroup(1, [l1], 5, "bottom_edge")
    gmsh.model.addPhysicalGroup(1, [l3], 6, "top_edge")

    # Extract nodes and elements
    node_tags, node_coords, _ = gmsh.model.mesh.getNodes()
    tag_map = {tag: i for i, tag in enumerate(node_tags)}
    nodes = jnp.array(node_coords).reshape(-1, 3)[:, :2]

    elem_types, elem_tags, elem_node_tags = gmsh.model.mesh.getElements(2)
    elements = jnp.array(elem_node_tags[0]).reshape(-1, 3) - 1

    pg = extract_physical_groups(tag_map)
    gmsh.finalize()
    return Mesh(nodes, elements=elements), pg

Creating Mesh¤

We define a plate with a hole at its center. Because of symmetry we only model, one half of the plate.

lx = 1.0
ly = 2.0
r = 0.2
mesh, pg = generate_plate(lx, ly, r)
n_dofs_per_node = 2
n_nodes = mesh.coords.shape[0]
n_dofs = n_dofs_per_node * n_nodes
plot_mesh(mesh)
Info    : Meshing 1D...
Info    : [  0%] Meshing curve 1 (Circle)
Info    : [ 20%] Meshing curve 2 (Line)
Info    : [ 40%] Meshing curve 3 (Line)
Info    : [ 60%] Meshing curve 4 (Line)
Info    : [ 70%] Meshing curve 5 (Line)
Info    : [ 90%] Meshing curve 6 (Line)
Info    : Done meshing 1D (Wall 0.00501475s, CPU 0.005614s)
Info    : Meshing 2D...
Info    : Meshing surface 1 (Plane, Frontal-Delaunay)
Info    : Done meshing 2D (Wall 0.12732s, CPU 0.120661s)
Info    : 3518 nodes 7039 elements
Extracting physical groups from Gmsh model...

img

To define rigid body constraint, we define a rigid disk of radius = 0.15, which is attached to the top right part of the plate.

radius = 0.15
disk_center = jnp.array([lx * 3 / 4, ly / 2 * 3 / 4])
disk_nodes = jnp.where(
    jnp.linalg.norm(mesh.coords - disk_center, axis=1) < (radius + 1e-12)
)[0]

We can now define the Operator object which takes the mesh and the associated element type which in this case is Tri3.

from tatva.element import Tri3

op = Operator(mesh, Tri3())

Defining boundaries¤

We start with identify the degrees of freedom associated with boundaries

boundary_left = jnp.where(jnp.isclose(mesh.coords[:, 0], 0.0))[0]
boundary_right = jnp.where(jnp.isclose(mesh.coords[:, 0], lx))[0]
point_at_y_0 = jnp.where(
    jnp.isclose(mesh.coords[:, 0], lx) & jnp.isclose(mesh.coords[:, 1], 0.0)
)[0][0]
boundary_left_bottom = jnp.where(
    (jnp.isclose(mesh.coords[:, 0], 0.0)) & (mesh.coords[:, 1] < 0.0)
)[0]

fixed_dofs = jnp.concatenate(
    [
        boundary_left_bottom * n_dofs_per_node,  # u_x = 0 at left-bottom nodes
        jnp.array([n_dofs + 1]),
    ]
)

free_dofs = jnp.setdiff1d(
    jnp.arange(mesh.coords.shape[0] * n_dofs_per_node + 3), fixed_dofs
)

Defining constraints¤

To apply the symmetric boundary conditions and the rigid body constraints we define to functions that take the displacements and assign the appropriate boundary conditions.

@jax.jit
def apply_dirichlet_bc(u_flat: Array) -> Array:
    u_flat = u_flat.at[fixed_dofs].set(0.0)
    return u_flat


@jax.jit
def apply_rigid_body_bc(u_flat: Array) -> Array:
    ux_c, uy_c, theta_c = u_flat[-3:]
    vecs = mesh.coords[disk_nodes] - disk_center

    cos, sin = jnp.cos(theta_c), jnp.sin(theta_c)
    rotation_matrix = jnp.array([[cos, -sin], [sin, cos]])
    rotated_vecs = vecs @ rotation_matrix.T
    u_rotation = rotated_vecs - vecs
    u_rigid = jnp.stack([ux_c, uy_c]) + u_rotation
    # Set the displacements at the disk nodes
    u_flat = u_flat.at[disk_nodes * n_dofs_per_node].set(u_rigid[:, 0])
    u_flat = u_flat.at[disk_nodes * n_dofs_per_node + 1].set(u_rigid[:, 1])
    return u_flat

Defining Energy Functional¤

We are now in a position to define the total global energy functional

class Material(NamedTuple):
    """Material properties for the elasticity operator."""

    mu: float
    lmbda: float

    @classmethod
    def from_youngs_poisson_2d(
        cls, E: float, nu: float, plane_stress: bool = False
    ) -> "Material":
        mu = E / 2 / (1 + nu)
        if plane_stress:
            lmbda = 2 * nu * mu / (1 - nu)
        else:
            lmbda = E * nu / (1 - 2 * nu) / (1 + nu)
        return cls(mu=mu, lmbda=lmbda)


mat = Material.from_youngs_poisson_2d(1, 0.3)


# Hyperelastic material model
@autovmap(grad_u=2)
def compute_deformation_gradient(grad_u):
    return jnp.eye(2) + grad_u


@autovmap(grad_u=2, mu=0, lmbda=0)
def strain_energy_mpc(grad_u, mu, lmbda):
    F = compute_deformation_gradient(grad_u)
    C = F.T @ F
    J = jnp.linalg.det(F)
    return (
        mat.mu / 2 * (jnp.trace(C) - 2)  # 2D case
        - mat.mu * jnp.log(J)
        + (mat.lmbda / 2) * (jnp.log(J)) ** 2
    )


def total_energy_mpc_full(u_flat: Array) -> Array:
    """Compute the total energy of the system."""
    u = u_flat[:-3].reshape(-1, n_dofs_per_node)
    e_density = strain_energy_mpc(op.grad(u), mat.mu, mat.lmbda)
    return op.integrate(e_density)


def total_energy_mpc(u_free: Array) -> Array:
    """Compute the total energy of the system."""
    u_full = (
        jnp.zeros(n_dofs + 3).at[free_dofs].set(u_free)
    )  # +3 for the rigid body DOFs
    u_full = apply_dirichlet_bc(u_full)
    u_full = apply_rigid_body_bc(u_full)
    return total_energy_mpc_full(u_full)


residual_mpc = jax.jacrev(total_energy_mpc)
Newton-Krylov solver for the nonlinear system.
from functools import partial


@partial(jax.jit, static_argnames=["gradient","compute_tangent"])
def newton_krylov_solver(
    u,
    gradient,
    compute_tangent,
):
    residual = gradient(u)
    norm_res = jnp.linalg.norm(residual)

    init_val = (u, 0, norm_res)

    def cond_fun(state):
        u, iiter, norm_res = state
        return jnp.logical_and(norm_res > 1e-8, iiter < 10)

    def body_fun(state):
        u, iiter, norm_res = state
        residual = gradient(u)

        A = jax.jit(partial(compute_tangent, u_prev=u))

        du, _ = jax.scipy.sparse.linalg.cg(A=A, b=-residual)

        u = u + du

        residual = gradient(u)
        norm_res = jnp.linalg.norm(residual)

        return (u, iiter + 1, norm_res)

    final_u, final_iiter, final_norm = jax.lax.while_loop(cond_fun, body_fun, init_val)
    jax.debug.print("  Residual: {res:.2e}", res=final_norm)

    return final_u, final_norm

We move the disk in the right direction while keeping its y position fixed. The disk is allowed to rotate \(\theta\) such that the disk experiences no moment about its center.

u_disk_x_dof = -2  # hardcoded last 2 dofs correspond to u_disk (one is fixed by BC)


@jax.jit
def func_mpc(u_free: Array) -> Array:
    res = residual_mpc(u_free)
    return res.at[u_disk_x_dof].add(-0.02)

def compute_tangent(du, u_prev):
    tangent = jax.jvp(func_mpc, (u_prev,), (du,))[1]
    return tangent

Solving the system¤

We solve the system for the applied rigid body displacement.

f_ext = jnp.zeros(len(free_dofs))

u_sol, norm_res  = newton_krylov_solver(
    u=jnp.zeros_like(f_ext),
    gradient=func_mpc,
    compute_tangent=compute_tangent,
)
Residual: 8.96e-11
Visualization of the deformed configuration and strain energy density.
from matplotlib.collections import LineCollection
from matplotlib.tri import Triangulation

_u = jnp.zeros(n_dofs + 3).at[free_dofs].set(u_sol)
_u = apply_rigid_body_bc(_u)
u_solid = _u.at[:-3].get()
u_disk = _u.at[-3:].get()

u = u_solid.reshape(-1, n_dofs_per_node)

fig, ax = plt.subplots(figsize=(3.4, 3))
x_final = mesh.coords + u
tri = Triangulation(x_final[:, 0], x_final[:, 1], mesh.elements)

e_sig = strain_energy_mpc(op.grad(u), mat.mu, mat.lmbda).squeeze()
cb = ax.tripcolor(tri, e_sig, alpha=0.95, cmap="managua")
ax.set_aspect("equal")
plt.colorbar(cb, ax=ax, label=r"$\Psi_{\varepsilon}$")

edge_elems = jnp.concatenate(
    (
        pg["hole_boundary"],
        pg["left_edge"],
        pg["right_edge"],
        pg["bottom_edge"],
        pg["top_edge"],
    )
)
segments = mesh.coords[edge_elems][:, :, :2]
coll_boundary = LineCollection(
    segments + u[edge_elems][:, :, :2], colors="k", linewidths=0.5
)
ax.add_collection(coll_boundary)
ax.scatter(
    (mesh.coords[disk_nodes] + u[disk_nodes])[:, 0],
    (mesh.coords[disk_nodes] + u[disk_nodes])[:, 1],
    color="k",
    s=1,
    alpha=0.3,
)

rigid_circle = plt.Circle(
    disk_center + u_disk[:2],
    radius,
    color="k",
    fc="#0000003f",
    linewidth=0.7,
)
udx, udy, theta_d = u_disk
rigid_cross = plt.Line2D(
    [
        disk_center[0] + udx - np.cos(theta_d) * radius,
        disk_center[0] + udx + np.cos(theta_d) * radius,
    ],
    [
        disk_center[1] + udy - np.sin(theta_d) * radius,
        disk_center[1] + udy + np.sin(theta_d) * radius,
    ],
    color="k",
    linewidth=0.7,
)
rigid_cross_vertical = plt.Line2D(
    [
        disk_center[0] + udx - np.sin(theta_d) * radius,
        disk_center[0] + udx + np.sin(theta_d) * radius,
    ],
    [
        disk_center[1] + udy + np.cos(theta_d) * radius,
        disk_center[1] + udy - np.cos(theta_d) * radius,
    ],
    color="k",
    linewidth=0.7,
)
ax.add_line(rigid_cross)
ax.add_line(rigid_cross_vertical)
ax.add_artist(rigid_circle)

ax.grid()
ax.set(
    xlabel="$x$",
    ylabel="$y$",
)
plt.show()
#u_disk

img