Lifter

Problem

In constrained problems, we often solve only for the free degrees of freedom (DOFs) and reconstruct the full vector afterwards. Typical constraints include:

  • Dirichlet conditions (fixed values)
  • Periodic constraints (slave DOFs copied from master DOFs)

Without a helper, we have to manage these index sets by hand and repeat the lifting logic wherever it is needed.

A manual implementation usually looks like this:

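u_full = jnp.zeros(n_dofs)  # start from an all-zero full vector
u_full = u_full.at[free_dofs].set(u_reduced)  # scatter the free DOFs
u_full = u_full.at[dirichlet_dofs].set(dirichlet_values)  # impose the fixed values
u_full = u_full.at[periodic_dofs].set(u_full[periodic_master_dofs])  # copy masters into slaves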

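This is easy to get wrong: free_dofs must stay consistent with every constrained set, and the updates must run in the right order (a periodic slave can only copy its master once the master holds its final value). Both become harder to maintain as the constraint sets evolve.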
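To make the ordering pitfall concrete, consider a contrived variant of the constraints (not the ones used in the rest of this guide) in which periodic slave DOF 6 is tied to DOF 7, which is itself fixed to 5.0. The two hypothetical helpers below, written in plain jax.numpy, differ only in the order of the updates:

```python
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 4, 5])  # maintained by hand
dirichlet_dofs = jnp.array([0, 7])
dirichlet_values = jnp.array([0.0, 5.0])
periodic_dofs = jnp.array([6])          # slave DOF
periodic_master_dofs = jnp.array([7])   # its master is a Dirichlet DOF
u_reduced = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])


def lift_dirichlet_first(u_reduced):
    u = jnp.zeros(n_dofs).at[free_dofs].set(u_reduced)
    u = u.at[dirichlet_dofs].set(dirichlet_values)
    # The master already holds 5.0 when the slave copies it.
    return u.at[periodic_dofs].set(u[periodic_master_dofs])


def lift_periodic_first(u_reduced):
    u = jnp.zeros(n_dofs).at[free_dofs].set(u_reduced)
    # Bug: the slave copies its master before the master is fixed.
    u = u.at[periodic_dofs].set(u[periodic_master_dofs])
    return u.at[dirichlet_dofs].set(dirichlet_values)


correct = lift_dirichlet_first(u_reduced)
wrong = lift_periodic_first(u_reduced)
```

Only lift_dirichlet_first gives the slave the value 5.0; the other silently copies a stale 0.0. Nothing fails loudly, which is why this bookkeeping is worth centralizing.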

Solution

Use Lifter to define constraints once and map consistently between:

  • reduced vectors (free DOFs only)
  • full vectors (all DOFs)

import jax
import jax.numpy as jnp
from tatva.lifter import DirichletBC, Lifter, PeriodicMap

n_dofs = 8

Define constraints

Here we fix the two boundary DOFs 0 and 7 (Dirichlet) and tie the periodic slave DOFs 4 and 6 to the master DOFs 1 and 3, respectively.

dirichlet = DirichletBC(
    dofs=jnp.array([0, 7]),
    values=jnp.array([0.0, 0.0]),  # default is zeros, so this line is optional
)
periodic = PeriodicMap(
    dofs=jnp.array([4, 6]),
    master_dofs=jnp.array([1, 3]),
)

lifter = Lifter(n_dofs, dirichlet, periodic)

Lifter derives the constrained and free DOF sets, along with the full and reduced system sizes.

print("constrained dofs:", lifter.constrained_dofs)
print("free dofs:", lifter.free_dofs)
print("size:", lifter.size)
print("size_reduced:", lifter.size_reduced)
constrained dofs: [0 4 6 7]
free dofs: [1 2 3 5]
size: 8
size_reduced: 4
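These sets follow from simple set algebra: the constrained DOFs are the union of the Dirichlet DOFs and the periodic slaves, and the free DOFs are the complement. The sketch below reproduces the printed values with jax.numpy set routines; it illustrates the bookkeeping and is an assumption about the idea, not tatva's actual implementation.

```python
import jax.numpy as jnp

n_dofs = 8
dirichlet_dofs = jnp.array([0, 7])
periodic_slave_dofs = jnp.array([4, 6])  # the masters 1 and 3 stay free

# Constrained DOFs: sorted union of fixed and slave DOFs (duplicates removed).
constrained_dofs = jnp.union1d(dirichlet_dofs, periodic_slave_dofs)
# Free DOFs: every index that is not constrained.
free_dofs = jnp.setdiff1d(jnp.arange(n_dofs), constrained_dofs)

size = n_dofs
size_reduced = int(free_dofs.shape[0])
print("constrained dofs:", constrained_dofs)
print("free dofs:", free_dofs)
print("size:", size, "size_reduced:", size_reduced)
```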

Lifting (free -> full)

Use lift_from_zeros to reconstruct the full vector from a zero base state: the reduced values fill the free DOFs, and the constraints determine the rest.

u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
u_full = lifter.lift_from_zeros(u_reduced)

print("u_reduced:", u_reduced)
print("u_full:", u_full)
u_reduced: [10. 20. 30. 40.]
u_full: [ 0. 10. 20. 30. 10. 40. 30.  0.]
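For constraints like these, lift_from_zeros behaves as an affine map, u_full = P u_reduced + g, where P routes each reduced entry to its free DOF (and to any slave of that DOF) and g holds the Dirichlet values. The sketch below builds P explicitly for this example and reproduces the output above; it is a mental model under that assumption, not how tatva stores the map.

```python
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
dirichlet_dofs, dirichlet_values = jnp.array([0, 7]), jnp.array([0.0, 0.0])
slave_dofs, master_dofs = jnp.array([4, 6]), jnp.array([1, 3])

# Column k of P sends reduced entry k to full DOF free_dofs[k].
n_free = free_dofs.shape[0]
P = jnp.zeros((n_dofs, n_free)).at[free_dofs, jnp.arange(n_free)].set(1.0)
# Each slave row copies its master row (valid here because masters are free).
P = P.at[slave_dofs].set(P[master_dofs])

# The offset vector g carries the prescribed Dirichlet values.
g = jnp.zeros(n_dofs).at[dirichlet_dofs].set(dirichlet_values)

u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
u_full = P @ u_reduced + g
```

By the chain rule, the gradient of an energy with respect to the reduced vector is P^T times its gradient with respect to the full vector. The diagonal of P^T P counts how many full DOFs each reduced entry feeds: 2 for the two periodic masters, 1 otherwise.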

Use lift when you want to start from a previous iterate and only overwrite free DOFs before applying constraints.

u_prev = jnp.linspace(-1.0, 1.0, n_dofs)
u_full_from_prev = lifter.lift(u_reduced, u_prev)

print("previous full state:", u_prev)
print("lifted full state:", u_full_from_prev)
previous full state: [-1.         -0.71428573 -0.42857143 -0.14285709  0.1428572   0.42857146
  0.71428585  1.        ]
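lifted full state: [ 0. 10. 20. 30. 10. 40. 30.  0.]

Here the result matches lift_from_zeros: every entry of u_prev is overwritten, the free DOFs by u_reduced and the constrained DOFs by the constraints.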
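A sketch of lift, assuming (consistently with the output above) that it writes the reduced values into the previous state and then applies every constraint. The helper lift_sketch is hypothetical; it yields the same full vector for two unrelated base states because the free and constrained sets together cover all eight DOFs.

```python
import jax
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
dirichlet_dofs, dirichlet_values = jnp.array([0, 7]), jnp.array([0.0, 0.0])
slave_dofs, master_dofs = jnp.array([4, 6]), jnp.array([1, 3])


def lift_sketch(u_reduced, u_base):
    # Start from u_base instead of zeros and overwrite the free DOFs...
    u = u_base.at[free_dofs].set(u_reduced)
    # ...then apply every constraint on top.
    u = u.at[dirichlet_dofs].set(dirichlet_values)
    return u.at[slave_dofs].set(u[master_dofs])


u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
from_linspace = lift_sketch(u_reduced, jnp.linspace(-1.0, 1.0, n_dofs))
random_base = jax.random.normal(jax.random.PRNGKey(0), (n_dofs,))
from_random = lift_sketch(u_reduced, random_base)
```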

Usage in a total energy function

def total_energy(z_full: jax.Array) -> jax.Array:
    (u,) = Solution(z_full)  # See the guide for 'Compound'
    E = ...  # compute energy from u
    return E

def total_energy_free(z_free: jax.Array) -> jax.Array:
    z_full = lifter.lift_from_zeros(z_free)  # free DOFs -> full vector
    return total_energy(z_full)

residual_fn = jax.jacrev(total_energy_free)

Or, equivalently, inline:

residual_fn = jax.jacrev(lambda z_free: total_energy(lifter.lift_from_zeros(z_free)))
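Solution and the body of total_energy above depend on the Compound guide. As a self-contained stand-in, the sketch below uses a hypothetical model (a chain of unit springs with a unit load on every DOF) and a hand-written replacement for lifter.lift_from_zeros, assuming the same constraints as above:

```python
import jax
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
dirichlet_dofs = jnp.array([0, 7])
slave_dofs, master_dofs = jnp.array([4, 6]), jnp.array([1, 3])


def lift_from_zeros(z_free):
    """Hypothetical stand-in for lifter.lift_from_zeros (zero Dirichlet values)."""
    z_full = jnp.zeros(n_dofs).at[free_dofs].set(z_free)
    z_full = z_full.at[dirichlet_dofs].set(0.0)
    return z_full.at[slave_dofs].set(z_full[master_dofs])


def total_energy(u):
    # Hypothetical model: unit springs between neighbours, unit load per DOF.
    elastic = 0.5 * jnp.sum(jnp.diff(u) ** 2)
    load = jnp.sum(u)
    return elastic - load


residual_fn = jax.jacrev(lambda z_free: total_energy(lift_from_zeros(z_free)))
r0 = residual_fn(jnp.zeros(4))
```

At z_free = 0 only the load contributes. Entries 0 and 2 of the residual are -2 rather than -1 because free DOFs 1 and 3 also carry the load of their periodic slaves 4 and 6, while the Dirichlet DOFs drop out entirely.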

Reducing (full -> free)

reduce extracts the reduced vector from any full vector.

u_reduced_back = lifter.reduce(u_full)
print("reduced from full:", u_reduced_back)
reduced from full: [10. 20. 30. 40.]
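Assuming reduce is a plain gather of the free entries, which matches the printed result, it fits in one line. The hypothetical reduce_sketch below also shows that whatever is stored at the constrained DOFs is simply discarded:

```python
import jax.numpy as jnp

free_dofs = jnp.array([1, 2, 3, 5])


def reduce_sketch(u_full):
    # Keep only the free entries; constrained entries are dropped.
    return u_full[free_dofs]


u_full = jnp.array([0.0, 10.0, 20.0, 30.0, 10.0, 40.0, 30.0, 0.0])
# Change only the constrained DOFs 0, 4, 6 and 7.
u_other = u_full.at[jnp.array([0, 4, 6, 7])].set(-99.0)
r1, r2 = reduce_sketch(u_full), reduce_sketch(u_other)
```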

Note

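The round trip reduce(lift_from_zeros(u_reduced)) returns u_reduced. The reverse round trip, lift_from_zeros(reduce(u_full)), returns u_full only if u_full already satisfies the constraints.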
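Both directions of the round trip can be checked with hand-written stand-ins for lift_from_zeros and reduce (hypothetical helpers assuming the semantics shown above). The vector v deliberately violates the constraints, so lifting its reduction changes it; repeating the round trip on the result changes nothing, so the composition is idempotent.

```python
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
dirichlet_dofs = jnp.array([0, 7])
slave_dofs, master_dofs = jnp.array([4, 6]), jnp.array([1, 3])


def lift_sketch(u_reduced):
    u = jnp.zeros(n_dofs).at[free_dofs].set(u_reduced)
    u = u.at[dirichlet_dofs].set(0.0)
    return u.at[slave_dofs].set(u[master_dofs])


def reduce_sketch(u_full):
    return u_full[free_dofs]


u_r = jnp.array([10.0, 20.0, 30.0, 40.0])
round_trip_ok = bool(jnp.array_equal(reduce_sketch(lift_sketch(u_r)), u_r))

v = jnp.arange(n_dofs, dtype=jnp.float32)  # violates u[7] = 0 and u[4] = u[1]
v_back = lift_sketch(reduce_sketch(v))
v_again = lift_sketch(reduce_sketch(v_back))
```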

Build constraints incrementally

You can append constraints with add, which returns a new Lifter.

base_lifter = Lifter(n_dofs)
constrained_lifter = base_lifter.add(dirichlet).add(periodic)

print("base size_reduced:", base_lifter.size_reduced)
print("constrained size_reduced:", constrained_lifter.size_reduced)

JAX compatibility

Lifter can be used inside JAX-transformed functions, for example under jax.jit and jax.grad.

def reduced_energy(u_r: jax.Array) -> jax.Array:
    u_f = lifter.lift_from_zeros(u_r)
    return jnp.sum(u_f**2)


energy_jit = jax.jit(reduced_energy)
energy_grad = jax.grad(reduced_energy)

print("energy:", energy_jit(u_reduced))
print("gradient:", energy_grad(u_reduced))

Summary

Use Lifter when solving constrained systems in reduced coordinates.

  • centralizes DOF bookkeeping for constraints
  • provides explicit lift / reduce operations
  • keeps reduced-space code clean and JAX-friendly