Skip to content

3D Multi-phase Material¤

import jax

jax.config.update("jax_enable_x64", True)  # use double-precision
jax.config.update("jax_platforms", "cpu")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

import jax.numpy as jnp
from jax import Array
import numpy as np

In this example, we solve a linear elasticity problem of a multiphase material in 3D.

from xpektra import (
    SpectralSpace,
    TensorOperator,
    make_field,
)
from xpektra.scheme import RotatedDifference, Fourier
from xpektra.projection_operator import GalerkinProjection
from xpektra.solvers.nonlinear import (  # noqa: E402
    conjugate_gradient_while,
    newton_krylov_solver,
)

import equinox as eqx
from scipy.spatial.distance import cdist


def generate_multiphase_material_3d(
    size, num_phases=5, num_seeds=None, random_state=None
):
    """
    Generate a 3D multiphase material using Voronoi tessellation.

    Parameters:
    -----------
    size : int or tuple of 3 ints
        Size of the 3D array. If int, creates a cubic array of size (size, size, size).
        If tuple, specifies (nx, ny, nz) dimensions.
    num_phases : int, default=5
        Number of different phases (materials) in the microstructure.
    num_seeds : int, optional
        Number of Voronoi seed points. If None, defaults to num_phases * 3.
    random_state : int, optional
        Random seed for reproducibility.

    Returns:
    --------
    material : numpy.ndarray
        3D array with values from 0 to 1, where each phase is represented by
        values in the range [i/num_phases, (i+1)/num_phases) for phase i.
    seed_points : numpy.ndarray
        Array of seed points used for Voronoi tessellation.
    """

    if random_state is not None:
        np.random.seed(random_state)

    # Handle size parameter
    if isinstance(size, int):
        nx, ny, nz = size, size, size
    else:
        nx, ny, nz = size

    # Default number of seeds
    if num_seeds is None:
        num_seeds = num_phases * 3

    # Generate random seed points in 3D space
    seed_points = np.random.rand(num_seeds, 3)
    seed_points[:, 0] *= nx
    seed_points[:, 1] *= ny
    seed_points[:, 2] *= nz

    # Assign each seed to a phase
    seed_phases = np.random.randint(0, num_phases, num_seeds)

    # Create coordinate grids
    x, y, z = np.meshgrid(np.arange(nx), np.arange(ny), np.arange(nz), indexing="ij")
    grid_points = np.column_stack([x.ravel(), y.ravel(), z.ravel()])

    # Calculate distances from each grid point to all seed points
    distances = cdist(grid_points, seed_points)

    # Find closest seed for each grid point
    closest_seeds = np.argmin(distances, axis=1)

    # Assign phases based on closest seed
    phase_array = seed_phases[closest_seeds].reshape(nx, ny, nz)

    # Convert to values between 0 and 1
    material = phase_array.astype(float) / num_phases + np.random.rand(nx, ny, nz) / (
        num_phases * 10
    )
    material = np.clip(material, 0, 1)

    return material, seed_points

We use the generate_multiphase_material_3d function to generate a 3D multiphase material. The function returns the material and the seed points.

N = 31
ndim = 3
length = 1

structure, seeds = generate_multiphase_material_3d(
    size=N, num_phases=5, num_seeds=15, random_state=42
)

To simplify the execution of the code, we define a ElasticityOperator class that contains the Fourier-Galerkin operator, the spatial operators, the tensor operators, and the FFT and IFFT operators. The __init__ method initializes the operator and the __call__ method computes the stresses in the real space given as

$$ \mathcal{F}^{-1} \left( \mathbb{G}:\mathcal{F}(\sigma) \right)

$$

We define the grid size and the length of the RVE and construct the structure of the RVE.

tensor = TensorOperator(dim=ndim)
space = SpectralSpace(size=N, dim=ndim, length=length)

Next, we define the material parameters.

λ0 = structure.copy()
μ0 = structure.copy()
Ghat = GalerkinProjection(scheme=RotatedDifference(space=space)).compute_operator()


dofs_shape = make_field(dim=ndim, N=N, rank=2).shape

The linear elasticity strain energy is given as $$ W = \frac{1}{2} \int_{\Omega} (\lambda \text{tr}(\epsilon)^2+ \mu \text{tr}(\epsilon : \epsilon ) ) d\Omega $$

We define a python function to compute the strain energy and then use the jax.jacrev function to compute the stress tensor.

@eqx.filter_jit
def strain_energy(eps_flat: Array) -> float:
    eps = eps_flat.reshape(dofs_shape)
    eps_sym = 0.5 * (eps + tensor.trans(eps))
    energy = 0.5 * jnp.multiply(λ0, tensor.trace(eps_sym) ** 2) + jnp.multiply(
        μ0, tensor.trace(tensor.dot(eps_sym, eps_sym))
    )
    return energy.sum()


compute_stress = jax.jacrev(strain_energy)
class Residual(eqx.Module):
    """A callable module that computes the residual vector."""

    Ghat: Array
    space: SpectralSpace = eqx.field(static=True)
    tensor_op: TensorOperator = eqx.field(static=True)
    dofs_shape: tuple = eqx.field(static=True)

    # We can even pre-define the stress function if it's always the same
    # For this example, we'll keep your original `compute_stress` function
    # available in the global scope.

    @eqx.filter_jit
    def __call__(self, eps_flat: Array) -> Array:
        """
        This makes instances of this class behave like a function.
        It takes only the flattened vector of unknowns, as required by the solver.
        """
        # eps = eps_flat.reshape(self.dofs_shape)
        sigma = compute_stress(eps_flat)  # Assumes compute_stress is defined elsewhere

        residual_field = self.space.ifft(
            self.tensor_op.ddot(
                self.Ghat, self.space.fft(sigma.reshape(self.dofs_shape))
            )
        )
        return jnp.real(residual_field).reshape(-1)


class Jacobian(eqx.Module):
    """A callable module that represents the Jacobian operator (tangent)."""

    Ghat: Array
    space: SpectralSpace = eqx.field(static=True)
    tensor_op: TensorOperator = eqx.field(static=True)
    dofs_shape: tuple = eqx.field(static=True)

    @eqx.filter_jit
    def __call__(self, deps_flat: Array) -> Array:
        """
        The Jacobian is a linear operator, so its __call__ method
        represents the Jacobian-vector product.
        """
        # deps = deps_flat.reshape(self.dofs_shape)
        # Assuming linear elasticity, the tangent is the same as the residual operator
        dsigma = compute_stress(deps_flat)

        jvp_field = self.space.ifft(
            self.tensor_op.ddot(
                self.Ghat, self.space.fft(dsigma.reshape(self.dofs_shape))
            )
        )
        return jnp.real(jvp_field).reshape(-1)
eps = make_field(dim=ndim, N=N, rank=2)
residual_fn = Residual(Ghat=Ghat, space=space, tensor_op=tensor, dofs_shape=eps.shape)
jacobian_fn = Jacobian(Ghat=Ghat, space=space, tensor_op=tensor, dofs_shape=eps.shape)
deps = make_field(dim=ndim, N=N, rank=2)

applied_strains = jnp.diff(jnp.linspace(0, 2e-2, num=2))


for inc, deps_avg in enumerate(applied_strains):
    # solving for elasticity
    deps[:, :, 0, 0] = deps_avg
    b = -residual_fn(deps)
    eps = eps + deps

    final_state = newton_krylov_solver(
        state=(deps, b, eps),
        gradient=residual_fn,
        jacobian=jacobian_fn,
        tol=1e-6,
        max_iter=20,
        krylov_solver=conjugate_gradient_while,
        krylov_tol=1e-6,
        krylov_max_iter=20,
    )
    eps = final_state[2]

    print("step", inc, "time", inc)

sig = compute_stress(eps).reshape(dofs_shape)
CG error = 0.22234152087832
CG error = 0.00000045775788
CG error = 0.00000001828671
CG error = 0.00000000343622
CG error = 0.00000000036042
CG error = 0.00000000009831
CG error = 0.00000000001393
CG error = 0.00000000000470
Converged, Residual value : 9.766911206116589e-07
step 0 time 0

Lets us now plot the stress tensor along the x-plane.

import matplotlib.pyplot as plt  # noqa: E402
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 4))

cax1 = ax1.imshow(sig.at[:, :, int(N / 2), 0, 0].get(), cmap="managua_r")

divider = make_axes_locatable(ax1)
cax = divider.append_axes("top", size="10%", pad=0.2)
fig.colorbar(cax1, cax=cax, label=r"$\sigma_{xx}$", orientation="horizontal", location="top")


cax2 = ax2.imshow(sig.at[:, :, int(N / 2), 0, 1].get(), cmap="managua_r")
divider = make_axes_locatable(ax2)
cax = divider.append_axes("top", size="10%", pad=0.2)
fig.colorbar(cax2, cax=cax, label=r"$\sigma_{xy}$", orientation="horizontal", location="top")

cax3 = ax3.imshow(sig.at[:, :, int(N / 2), 1, 1].get(), cmap="managua_r")
divider = make_axes_locatable(ax3)
cax = divider.append_axes("top", size="10%", pad=0.2)
fig.colorbar(cax3, cax=cax, label=r"$\sigma_{yy}$", orientation="horizontal", location="top")

plt.show()

img