Sparse Differentiation¤
Problem¤
In finite element analysis, the tangent stiffness matrix (Hessian) \(\mathbf{K}\) is typically sparse. A node only interacts with its immediate neighbors, meaning most entries in \(\mathbf{K}\) are zero. However, standard automatic differentiation (AD) in JAX (jax.jacfwd or jax.jacrev) is unaware of this sparsity. It attempts to recover the full dense matrix by evaluating the Jacobian-Vector Product (JVP) once for every degree of freedom.
For a mesh with \(N\) degrees of freedom:
- Naive AD Cost: \(N \times t_{\text{residual}}\) (Prohibitive for large \(N\))
- Memory: \(O(N^2)\) (Explodes quickly)
We need a way to compute only the non-zero entries of \(\mathbf{K}\) efficiently, ideally in constant time with respect to the mesh size.
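To make the cost concrete, here is a minimal sketch (the chain-of-springs energy below is a hypothetical stand-in for a real finite element residual): jax.hessian materializes all \(N^2\) entries even though the true coupling is tridiagonal.
import jax
import jax.numpy as jnp

def toy_energy(u):
    # Chain of springs: each DOF couples only to its two neighbours,
    # so the Hessian is tridiagonal.
    return 0.5 * jnp.sum((u[1:] - u[:-1]) ** 2)

n = 1000
K_dense = jax.hessian(toy_energy)(jnp.zeros(n))
print(K_dense.shape)               # (1000, 1000): 10^6 stored entries...
print(jnp.count_nonzero(K_dense))  # ...of which only 3n - 2 = 2998 are non-zero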
Solution¤
tatva.sparse provides a sparse differentiation engine that reduces the cost from \(O(N)\) to \(O(c)\), where \(c\) is the distance-2 chromatic number of the sparsity graph (typically small and independent of mesh size, e.g., ~10-20 for 2D meshes).
The process has three steps:
- Sparsity Pattern: Identify the non-zero structure.
- Coloring: Group non-interacting DOFs.
- Differentiation: Compute the matrix in batches.
import jax
jax.config.update("jax_enable_x64", True) # use double-precision
import jax.numpy as jnp
import scipy.sparse as sp
from tatva import sparse, Mesh
mesh = Mesh.unit_square(n_x=1, n_y=1, type="triangle", dim=2)
n_dofs_per_node = 2
n_dofs = mesh.coords.shape[0] * n_dofs_per_node
Sparsity Pattern¤
First, we analyze the mesh connectivity to determine which DOFs interact. create_sparsity_pattern returns the indices of the non-zero entries.
# Extract sparsity topology from the mesh
sparsity_pattern = sparse.create_sparsity_pattern(
    mesh,
    n_dofs_per_node=n_dofs_per_node
)
# Convert to SciPy CSR format for efficient indexing and later use
sparsity_pattern_csr = sp.csr_matrix(
    (
        sparsity_pattern.data,
        (sparsity_pattern.indices[:, 0], sparsity_pattern.indices[:, 1]),
    )
)
print(f"Sparsity: {sparsity_pattern_csr.nnz} non-zeros")
Sparsity: 56 non-zeros
Sparsity Pattern
In tatva we provide a few functions to generate the sparsity pattern for specific problem types:
- sparse.create_sparsity_pattern - for a single physical field problem
- sparse.create_sparsity_pattern_KKT - for KKT problems
- sparse.reduce_sparsity_pattern - for a reduced system via condensation
Graph Coloring¤
We partition the degrees of freedom into independent sets (colors). Two DOFs share the same color only if they do not share an edge in the sparsity graph (Distance-1) and do not share a common neighbor (Distance-2). This ensures that when we perturb all DOFs of "Color A" simultaneously, their contributions to the Hessian do not overlap.
Coloring Algorithm
We use the tatva_coloring library to generate colors from a sparsity pattern and take the first return value, which is the array of colors. The implemented coloring algorithm is a naive greedy algorithm. One can easily switch to other coloring libraries, such as !pysparsematrixcolorings, whose more advanced algorithms are typically more efficient as they generate fewer colors.
from tatva_coloring import distance2_color_and_seeds
colors = distance2_color_and_seeds(
    row_ptr=sparsity_pattern_csr.indptr,
    col_idx=sparsity_pattern_csr.indices,
    n_dofs=n_dofs,
)[0]
print(f"Number of colors required: {jnp.max(colors) + 1}")
Number of colors required: 8
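To see why this coloring is sufficient, here is a hand-rolled sketch of the mechanics that jacfwd_with_batch automates below (the quadratic toy energy is an assumed stand-in, chosen so the Hessian is simply \(2\mathbf{I}\)): one JVP per color, followed by reading each non-zero entry back out.
# Illustrative only -- jacfwd_with_batch performs these steps internally.
toy_grad = jax.grad(lambda u: jnp.sum(u**2))  # toy energy; Hessian = 2*I
u0 = jnp.zeros(n_dofs)

n_colors = int(jnp.max(colors)) + 1
# One seed vector per color: 1.0 on every DOF of that color, 0.0 elsewhere.
seeds = jnp.stack([(colors == c).astype(u0.dtype) for c in range(n_colors)])

# n_colors JVPs instead of n_dofs JVPs: compressed[c] = K @ seeds[c]
compressed = jnp.stack([jax.jvp(toy_grad, (u0,), (s,))[1] for s in seeds])

# Distance-2 coloring makes same-colored columns structurally orthogonal,
# so every non-zero K[i, j] is recovered as compressed[colors[j], i].
print(compressed[colors[0], 0])  # 2.0 for the toy energy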
Sparse Differentiation¤
Finally, we use jacfwd_with_batch. This function automatically:
- Perturbs the input \(\boldsymbol{u}\) using the color groups.
- Evaluates the gradient efficiently (Batched JVPs).
- Reconstructs the values into the correct sparse matrix locations.
Info
We use our own implementation of sparse differentiation to make it scalable for large problems, but one can use libraries such as !sparsejac, which has been a source of inspiration for our implementation.
def energy_fn(u, delta):
    # Placeholder for the actual energy function
    return jnp.sum(u**2) + delta * jnp.sum(u)
gradient_fn = jax.jacrev(energy_fn)
# differentiate the residual using the sparsity information
K_sparse_fn = sparse.jacfwd_with_batch(
    gradient=gradient_fn,
    row_ptr=jnp.array(sparsity_pattern_csr.indptr),
    col_indices=jnp.array(sparsity_pattern_csr.indices),
    colors=jnp.array(colors),
    color_batch_size=mesh.elements.shape[0],  # Batch size for evaluating the element routine
)
u_current = jnp.zeros(n_dofs) # Example input
delta_current = 1.0 # Example parameter
K_sparse = K_sparse_fn(u_current, delta_current)
Info
The above function K_sparse_fn needs to be created only once, provided the sparsity pattern does not change. Once created, the generated function can be reused within the simulation loop.
If the energy function takes additional arguments, for example history parameters, then the generated K_sparse_fn takes the same arguments.
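If an explicit matrix is needed, for example for a direct solver, the values can be wrapped into a SciPy CSR matrix. The following is a sketch under an assumption: that K_sparse holds the non-zero values in the same CSR ordering as the row_ptr/col_indices passed to jacfwd_with_batch (check the tatva API for the actual return format).
import numpy as np
import scipy.sparse.linalg as spla

# ASSUMPTION: K_sparse contains the non-zero values in the CSR ordering
# defined by sparsity_pattern_csr.indptr / sparsity_pattern_csr.indices.
K_csr = sp.csr_matrix(
    (np.asarray(K_sparse), sparsity_pattern_csr.indices, sparsity_pattern_csr.indptr),
    shape=(n_dofs, n_dofs),
)

# e.g. a direct Newton step: solve K du = -r
r = gradient_fn(u_current, delta_current)
du = spla.spsolve(K_csr, -np.asarray(r))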
Tip
Sometimes we need to evaluate \(\mathbf{K}\) at fixed values of the additional arguments but updated values of \(\boldsymbol{u}\), for example in Newton-Raphson or staggered solvers. Use equinox.Partial from equinox to freeze some argument values. For example:
import equinox as eqx
K_sparse_partial = eqx.Partial(K_sparse_fn, delta=delta_current)
# newton iteration
for i in range(max_iter):  # max_iter: Newton iteration cap
    ...
    # call K_sparse_partial with just u
    K_sparse = K_sparse_partial(u_new)
    ...
Matrix-Free Operators¤
For extremely large problems, even storing the sparse matrix indices might be too memory-intensive. In these cases, we can use Matrix-Free methods. Since we have the energy functional, we can compute the Jacobian-vector product (JVP) \(\mathbf{K}\mathbf{v}\) directly without ever forming \(\mathbf{K}\).
def jacobian_vector_product(u, v):
    """
    Computes (Hessian of energy at u) @ v, i.e. the JVP of the gradient
    function evaluated at u in the direction v.

    Args:
        u: Current solution (shape: [n_dofs])
        v: Vector to multiply with the Hessian (shape: [n_dofs])

    Returns:
        The product of the Hessian of the energy function at u with the
        vector v (shape: [n_dofs])
    """
    # energy_fn takes (u, delta); freeze delta so we differentiate w.r.t. u only
    grad_fn = lambda x: jax.jacrev(energy_fn)(x, delta_current)
    return jax.jvp(grad_fn, (u,), (v,))[1]
# This can be passed directly to iterative solvers like jax.scipy.sparse.linalg.cg
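# Usage sketch (assumptions: taking b as the negative residual is illustrative,
# and CG applies here because the energy Hessian is symmetric positive-definite).
b = -gradient_fn(u_current, delta_current)  # example right-hand side
du, _ = jax.scipy.sparse.linalg.cg(
    lambda v: jacobian_vector_product(u_current, v),  # matrix-free operator
    b,
)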