Skip to content

Open In Colab

Neural Operator Element Method¤

Colab Setup (Install Dependencies)
# Only run this if we are in Google Colab
if 'google.colab' in str(get_ipython()):
    print("Installing dependencies from pyproject.toml...")
    # System packages first: gmsh (imported later for mesh generation) and
    # xvfb + Mesa GL, which back the `pv.start_xvfb()` call below.
    !apt-get install gmsh 
    !apt-get install -qq xvfb libgl1-mesa-glx
    !pip install pyvista -qq
    # This installs the repo itself (and its dependencies)
    !pip install -q "git+https://github.com/smec-ethz/tatva-docs.git"

    import pyvista as pv

    # Render static images through a virtual framebuffer
    # (presumably because the Colab VM has no display -- confirm).
    pv.global_theme.jupyter_backend = 'static'
    pv.global_theme.notebook = True
    pv.start_xvfb()

    print("Installation complete!")
else:
    import pyvista as pv
    # Outside Colab, select PyVista's 'client' Jupyter backend.
    pv.global_theme.jupyter_backend = 'client'

In this notebook, we will implement a neural constitutive model. A neural constitutive model uses neural networks to represent the relationship between stress and strain in materials. This approach allows for more flexible and accurate modeling of complex material behaviors compared to traditional constitutive models.

import os

import jax

jax.config.update("jax_enable_x64", True)  # Use double-precision for FEM stability

import equinox as eqx
import jax.experimental.sparse as jsp
import jax.numpy as jnp
import numpy as np
import pyvista as pv
import scipy.sparse as sp
from jax import Array
from jax_autovmap import autovmap
from tatva import Mesh, Operator, element, sparse
from tatva_coloring import distance2_color_and_seeds

Mesh and Material Setup¤

We start by defining the mesh and material properties for our simulation.

Mesh Generation
import gmsh
import meshio


def create_unstructured_3d_through_hole_mesh(
    L=10.0, H=5.0, a=3.0, mesh_size=0.8, filename="noem_3d_hole.msh"
):
    """
    Creates an unstructured tetrahedral mesh for a cuboid with a
    square through-hole along the Z-axis.

    Parameters:
    - L: Width/Length of the cuboid (X and Y).
    - H: Height of the cuboid (Z-axis).
    - a: Side of the square hole.
    - mesh_size: Characteristic mesh size (used as both min and max size).
    - filename: Path of the intermediate mesh file written by gmsh and read
      back with meshio. The file is deleted before the function returns.

    Returns:
    - nodes: Node coordinates (`mesh.points` as read by meshio).
    - fem_elements: Tetrahedron connectivity.
    - interface_node_ids: Sorted unique node indices of the triangles in the
      "Interface" physical group (the walls of the hole); if that group is
      absent, the nodes of all triangles in the file.
    """
    gmsh.initialize()
    try:
        gmsh.model.add("NOEM_ThroughHole")
        occ = gmsh.model.occ

        outer_vol = occ.addBox(-L / 2, -L / 2, 0, L, L, H)
        # The cutter sticks out 0.1 beyond both caps so the boolean cut
        # produces a clean through-hole.
        cutter_vol = occ.addBox(-a / 2, -a / 2, -0.1, a, a, H + 0.2)

        fem_vol, _ = occ.cut([(3, outer_vol)], [(3, cutter_vol)])
        occ.synchronize()

        all_surfaces = gmsh.model.getEntities(2)
        interface_surfaces = []

        for dim, tag in all_surfaces:
            mass_prop = occ.getCenterOfMass(dim, tag)
            # Check if surface is on the internal walls (x or y = +/- a/2)
            is_internal_x = (
                np.isclose(np.abs(mass_prop[0]), a / 2, atol=1e-3)
                and np.abs(mass_prop[1]) <= a / 2
            )
            is_internal_y = (
                np.isclose(np.abs(mass_prop[1]), a / 2, atol=1e-3)
                and np.abs(mass_prop[0]) <= a / 2
            )
            # Ensure it's not the top or bottom cap of the cuboid
            is_not_cap = not np.isclose(
                mass_prop[2], 0, atol=1e-3
            ) and not np.isclose(mass_prop[2], H, atol=1e-3)

            if (is_internal_x or is_internal_y) and is_not_cap:
                interface_surfaces.append(tag)

        gmsh.model.addPhysicalGroup(3, [fem_vol[0][1]], name="FEM_Volume")
        gmsh.model.addPhysicalGroup(2, interface_surfaces, name="Interface")

        gmsh.option.setNumber("Mesh.MeshSizeMin", mesh_size)
        gmsh.option.setNumber("Mesh.MeshSizeMax", mesh_size)
        gmsh.model.mesh.generate(3)
        gmsh.write(filename)
    finally:
        # Always release the gmsh session, even if geometry or meshing fails;
        # otherwise a later gmsh.initialize() would hit a stale session.
        gmsh.finalize()

    try:
        mesh = meshio.read(filename)
    finally:
        # The file is only a hand-off between gmsh and meshio; remove it even
        # when reading fails so no stray file is left behind.
        if os.path.exists(filename):
            os.remove(filename)

    nodes = mesh.points
    fem_elements = mesh.cells_dict["tetra"]

    if "Interface" in mesh.cell_sets_dict:
        # Get the triangles forming the internal boundary
        interface_tris = mesh.cells_dict["triangle"][
            mesh.cell_sets_dict["Interface"]["triangle"]
        ]
        interface_node_ids = np.unique(interface_tris)
    else:
        interface_node_ids = np.unique(mesh.cells_dict["triangle"])

    return nodes, fem_elements, interface_node_ids


def get_pyvista_grid(mesh, cell_type="quad"):
    """Build a `pv.UnstructuredGrid` from an object exposing `coords` and `elements`.

    2D coordinates are padded with a zero z-column. `cell_type` must be one of
    "quad", "triangle", "tetra" or "hexahedron"; any other value raises
    `KeyError`.
    """
    if mesh.coords.shape[1] == 2:
        z_column = np.zeros(shape=(mesh.coords.shape[0], 1))
        points = np.hstack((mesh.coords, z_column))
    else:
        points = np.array(mesh.coords)

    # One table instead of two parallel dicts: (nodes per cell, VTK cell type).
    cell_specs = {
        "quad": (4, pv.CellType.QUAD),
        "triangle": (3, pv.CellType.TRIANGLE),
        "tetra": (4, pv.CellType.TETRA),
        "hexahedron": (8, pv.CellType.HEXAHEDRON),
    }
    nodes_per_cell, vtk_cell_type = cell_specs[cell_type]
    n_cells = mesh.elements.shape[0]

    # Each connectivity row is prefixed with its node count before flattening.
    count_column = np.full(fill_value=nodes_per_cell, shape=(n_cells, 1))
    connectivity = np.hstack((count_column, mesh.elements))
    vtk_cell_types = np.full(fill_value=vtk_cell_type, shape=(n_cells,))

    return pv.UnstructuredGrid(connectivity.flatten(), vtk_cell_types, points)
# Side length of the square through-hole. `a` is not assigned anywhere before
# this point in the visible notebook, so it is set here to avoid a NameError.
# 3.0 matches the mesh generator's own default -- adjust if another size was
# intended.
a = 3.0

nodes, elements, interface_idx = create_unstructured_3d_through_hole_mesh(
    L=10.0, H=2.0, a=a, mesh_size=0.4
)

mesh = Mesh(coords=nodes, elements=elements)

# Three displacement components per node.
n_dofs_per_node = 3
n_nodes = mesh.coords.shape[0]
n_dofs = n_dofs_per_node * n_nodes

# Quick visual check of the generated mesh.
grid = get_pyvista_grid(mesh, cell_type="tetra")
pl = pv.Plotter()
pl.add_mesh(grid, show_edges=True)
pl.show()

FEM domain

We define a simple 3D bar of length \(L\), width \(W\), and height \(H\). The bar is fixed at one end and subjected to a force at the other end. We use tetrahedral elements for the mesh.

# 4-node tetrahedral element and the operator built from it on the mesh;
# `op.grad` and `op.integrate` are used by the energy functions below.
tetra = element.Tetrahedron4()
op = Operator(mesh, tetra)

# DOF bookkeeping: three displacement components per node.
n_dofs_per_node = 3
n_nodes = mesh.coords.shape[0]
n_dofs = n_nodes * n_dofs_per_node

Defining FEM energy functional¤

from typing import NamedTuple


class Material(NamedTuple):
    """Material properties for the elasticity operator.

    Holds the two Lamé parameters of the isotropic linear-elastic law
    ``sigma = 2 * mu * eps + lmbda * tr(eps) * I``.
    """

    mu: float  # Shear modulus (second Lamé parameter)
    lmbda: float  # First Lamé parameter


# Young's modulus and Poisson's ratio of the FEM domain.
E = 1e4
nu = 0.3

# Convert (E, nu) to the Lamé parameters consumed by `compute_stress`.
mu = E / (2 * (1 + nu))
lmbda = (E * nu) / (1 - 2 * nu) / (1 + nu)

mat = Material(mu, lmbda)


@autovmap(grad_u=2)
def compute_strain(grad_u):
    """Return the small-strain tensor, i.e. the symmetric part of `grad_u`."""
    transposed = jnp.transpose(grad_u)
    return (grad_u + transposed) * 0.5


@autovmap(eps=2, mu=0, lmbda=0)
def compute_stress(eps, mu, lmbda):
    """Return the isotropic linear-elastic stress for a 3x3 strain `eps`.

    Evaluates ``2 * mu * eps + lmbda * tr(eps) * I`` with a 3x3 identity.
    """
    identity = jnp.eye(3)
    shear_part = 2 * mu * eps
    volumetric_part = lmbda * jnp.trace(eps) * identity
    return shear_part + volumetric_part


@autovmap(grad_u=2, mu=0, lmbda=0)
def strain_energy(grad_u, mu, lmbda):
    """Return the strain energy density ``0.5 * sigma : eps`` for `grad_u`."""
    strain = compute_strain(grad_u)
    stress = compute_stress(strain, mu, lmbda)
    # Double contraction of stress with strain.
    contraction = jnp.einsum("ij,ij->", stress, strain)
    return 0.5 * contraction


@jax.jit
def total_fem_energy(u_flat: Array) -> float:
    """Return the strain energy of the FEM domain for a flat displacement vector.

    `u_flat` is reshaped to one row per node (`n_dofs_per_node` columns), its
    gradient is taken with `op.grad`, and the energy density is integrated
    with `op.integrate`.
    """
    nodal_u = u_flat.reshape(-1, n_dofs_per_node)
    density = strain_energy(op.grad(nodal_u), mat.mu, mat.lmbda)
    return op.integrate(density)

Defining the Neural Constitutive Model¤

The specific architecture employed for the neural strain energy density was a feed-forward Multi-Layer Perceptron (MLP). The network consisted of an input layer accepting the two scalar invariants \((I_1, J)\), followed by two hidden layers with 16 neurons each, and a final output layer producing the scalar energy value. To ensure that the second-order derivatives (Hessian) remained continuous and numerically stable, a `softplus` activation function was utilized across all hidden layers. This choice is critical as standard piecewise linear activations, such as `ReLU`, yield zero second derivatives almost everywhere, leading to immediate solver divergence.

\[ \psi_{\text{total}}(I_1, J) = \underbrace{\left[ \text{NN}(I_1, J; \theta) - \text{NN}(3, 1; \theta) \right]}_{\text{Shifted Neural Potential}} + \underbrace{\Psi_{\text{base}}(I_1, J)}_{\text{Stiffness Prior}} \]
class NeuralInclusion(eqx.Module):
    """Scalar energy of the inclusion as a function of the interface displacements.

    An MLP maps the flattened interface displacement vector to a scalar; the
    value is shifted so that zero displacement gives zero energy, and a
    quadratic term scaled by `stiffness_prior` is added.
    """

    network: eqx.nn.MLP
    stiffness_prior: float  # Helps with initial convergence

    def __init__(self, n_interface_dofs, key, stiffness_prior=1e-2):
        # Softplus keeps the second derivatives (Hessian) smooth.
        self.network = eqx.nn.MLP(
            in_size=n_interface_dofs,
            out_size="scalar",
            width_size=64,
            depth=3,
            activation=jax.nn.softplus,
            key=key,
        )
        self.stiffness_prior = stiffness_prior

    def __call__(self, u_interface):
        """
        Returns the shifted energy G(u) - G(0) plus the quadratic prior.
        u_interface: flattened array of displacements for nodes on the boundary
        """
        reference_energy = self.network(jnp.zeros_like(u_interface))
        shifted_energy = self.network(u_interface) - reference_energy

        quadratic_prior = 0.5 * self.stiffness_prior * jnp.sum(u_interface**2)

        return shifted_energy + quadratic_prior


# The network input is every DOF of every interface node, flattened.
neural_operator = NeuralInclusion(
    key=jax.random.PRNGKey(0),
    n_interface_dofs=n_dofs_per_node * len(interface_idx),
    stiffness_prior=1e4,  # the class default is 1e-2
)

Coupling the domains through energies¤

Now, we define the neural network architecture and the total strain energy density function based on the neural network defined above.

def total_energy(u_flat, neural_operator):
    """Return the FEM strain energy plus the neural energy of the interface DOFs.

    `u_flat` is the flat global displacement vector; the rows of the interface
    nodes (`interface_idx`) are gathered and flattened before being passed to
    `neural_operator`.
    """
    nodal_u = u_flat.reshape(-1, n_dofs_per_node)
    interface_u = nodal_u[interface_idx].flatten()

    return total_fem_energy(u_flat) + neural_operator(interface_u)

Using Coloring to compute Sparse Hessians¤

# Sparsity pattern derived from the mesh, converted to SciPy CSR to obtain the
# row-pointer / column-index arrays.
# NOTE(review): the pattern is built from the mesh only, while the MLP energy
# takes all interface DOFs as a single input -- confirm that the resulting
# interface-to-interface coupling is covered by this pattern.
sparsity_pattern = sparse.create_sparsity_pattern(
    mesh, n_dofs_per_node=n_dofs_per_node
)
sparsity_pattern_csr = sp.csr_matrix(
    (
        sparsity_pattern.data,
        (sparsity_pattern.indices[:, 0], sparsity_pattern.indices[:, 1]),
    )
)
indptr, indices = sparsity_pattern_csr.indptr, sparsity_pattern_csr.indices

# Only the first item returned by the coloring routine is kept.
colors = distance2_color_and_seeds(row_ptr=indptr, col_idx=indices, n_dofs=n_dofs)[0]

# Closure for the energy based on the NN weights
energy_fn = eqx.Partial(total_energy, neural_operator=neural_operator)
gradient_fn = jax.jacrev(energy_fn)

# Sparse Hessian: forward-mode Jacobian of the gradient, using the colors above.
K_sparse_fn = sparse.jacfwd_with_batch(
    gradient=gradient_fn,
    row_ptr=jnp.array(indptr),
    col_indices=jnp.array(indices),
    colors=jnp.array(colors),
    color_batch_size=mesh.elements.shape[0],
)

To check if the total energy at 0 deformation is zero, we can evaluate the total strain energy density function at the reference configuration where \(I_1 = 3\) and \(J = 1\). This ensures that the neural network's contribution is shifted appropriately, and the stiffness prior is also evaluated at this point.

Applying Boundary Conditions and Loads¤

# Boundary Conditions & Solver Setup
y_min = jnp.min(mesh.coords[:, 1])
y_max = jnp.max(mesh.coords[:, 1])

# Nodes lying on the y = y_max and y = y_min faces.
top_nodes = jnp.where(jnp.isclose(mesh.coords[:, 1], y_max))[0]
bottom_nodes = jnp.where(jnp.isclose(mesh.coords[:, 1], y_min))[0]

# y-direction DOFs at the top nodes
applied_dofs = top_nodes * n_dofs_per_node + 1

# x- and y-direction DOFs at the bottom nodes are held at zero.
zero_dofs = jnp.concatenate(
    [bottom_nodes * n_dofs_per_node + component for component in (0, 1)]
)

fixed_dofs = jnp.concatenate([applied_dofs, zero_dofs])

# Total prescribed displacement: 1.0 on the applied DOFs, zero elsewhere.
prescribed_values = jnp.zeros(n_dofs).at[applied_dofs].set(1.0)

zero_indices, one_indices = sparse.get_bc_indices(sparsity_pattern, fixed_dofs)

Defining Newton Solver¤

Newton Solver with Sparse Hessian
@eqx.filter_jit
def newton_sparse_solver(
    u,
    fext,
    gradient,
    hessian_sparse,
    fixed_dofs,
    zero_indices,
    one_indices,
    indptr,
    indices,
    tol=1e-8,
    max_iter=10,
):
    """Run Newton iterations on ``fext - gradient(u) = 0`` with fixed DOFs held.

    Parameters:
    - u: Initial flat displacement vector (prescribed values already applied).
    - fext: External force vector.
    - gradient: Callable returning the internal force vector for a given `u`.
    - hessian_sparse: Callable returning an object whose `.data` holds the
      non-zero values of the tangent matrix in CSR order.
    - fixed_dofs: DOF indices whose residual entries are zeroed, so they get
      no update.
    - zero_indices, one_indices: Positions in the CSR data array that are
      overwritten with 0 and 1 respectively before the solve (presumably the
      off-diagonal and diagonal entries of the fixed DOFs -- confirm against
      `sparse.get_bc_indices`).
    - indptr, indices: CSR row pointers and column indices of the matrix.
    - tol: Residual-norm threshold below which an iteration leaves `u`
      unchanged.
    - max_iter: Fixed number of scan iterations. Must be a Python int so the
      loop length is known at trace time.

    Returns:
    - (u, norm_res): the final displacement vector and the norm of its
      masked residual.
    """

    def masked_residual(u):
        # Out-of-balance force with the constrained entries removed.
        return (fext - gradient(u)).at[fixed_dofs].set(0.0)

    def step(u, _):
        residual = masked_residual(u)
        norm_res = jnp.linalg.norm(residual)

        jax.debug.print("residual={}", norm_res)

        def newton_update(u):
            # Reuses the residual computed above instead of evaluating the
            # gradient a second time inside the branch.
            K_sparse = hessian_sparse(u)
            K_data_lifted = K_sparse.data.at[zero_indices].set(0)
            K_data_lifted = K_data_lifted.at[one_indices].set(1)

            du = jsp.linalg.spsolve(
                K_data_lifted, indices=indices, indptr=indptr, b=residual
            )
            return u + du

        def keep(u):
            return u

        # The scan always runs max_iter steps; once converged, the remaining
        # steps take the no-op branch.
        return jax.lax.cond(norm_res > tol, newton_update, keep, u), None

    u, _ = jax.lax.scan(step, init=u, xs=None, length=max_iter)

    return u, jnp.linalg.norm(masked_residual(u))

Solving the System¤

# No external forces: the loading is purely displacement-controlled.
fext = jnp.zeros(n_dofs)

n_steps = 5
applied_displacement = prescribed_values / n_steps  # displacement increment

# Start from the undeformed configuration. Without this initialisation the
# first loop iteration reads `u_prev` before it exists (NameError).
u_prev = jnp.zeros(n_dofs)

residual_history = []

print("Starting Neural Constitutive Solver...")
for i in range(n_steps):  # load steps; Newton iterations run inside the solver
    # Apply this step's displacement increment on the constrained DOFs.
    u_prev = u_prev.at[fixed_dofs].add(applied_displacement[fixed_dofs])

    # The converged state becomes the starting point of the next load step.
    u_prev, rnorm = newton_sparse_solver(
        u_prev,
        fext,
        gradient_fn,
        K_sparse_fn,
        fixed_dofs,
        zero_indices,
        one_indices,
        indptr,
        indices,
    )

    residual_history.append(rnorm)

    print(f"Iteration {i}: Residual Norm = {rnorm:.4e}")

u_sol = u_prev.reshape(n_nodes, n_dofs_per_node)

Visualization of Results
# Tetrahedral grid of the undeformed mesh, built with the shared helper.
grid = get_pyvista_grid(mesh, cell_type="tetra")

pl = pv.Plotter()

# Post-process strains and stresses from the converged displacement field.
grad_u = op.grad(u_sol).squeeze()
strains = compute_strain(grad_u)
stresses = compute_stress(strains, mat.mu, mat.lmbda)

grid["u"] = np.array(u_sol)
grid["sigma_yy"] = stresses[:, 1, 1].flatten()

# Exaggerate the deformation and colour by the x-component of displacement.
warped = grid.warp_by_vector("u", factor=4.0)
pl.add_mesh(
    warped,
    show_edges=False,
    scalars="u",
    component=0,
    cmap="managua",
    show_scalar_bar=False,
)
pl.view_isometric()
pl.screenshot(
    "../assets/plots/neural_soft_inclusion_deformed_mesh.png",
    transparent_background=True,
)
pl.show()

Neural operator element with soft inclusion