Lifter
Problem
In constrained problems, we often solve only for the free degrees of freedom (DOFs) and reconstruct the full vector afterwards. Typical constraints include:
- Dirichlet conditions (fixed values)
- Periodic constraints (slave DOFs copied from master DOFs)
Without a helper, we manually manage index sets and repeat lifting logic across the codebase.
A manual implementation usually looks like this:
u_full = u_full.at[free_dofs].set(u_reduced)
u_full = u_full.at[dirichlet_dofs].set(dirichlet_values)
u_full = u_full.at[periodic_dofs].set(u_full[periodic_master_dofs])
This is easy to get wrong, especially when constraint sets evolve.
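To make the pitfall concrete, here is a minimal, self-contained sketch of the manual approach in plain JAX. The index sets and values are illustrative assumptions (they mirror the 8-DOF example used on this page), and the second variant deliberately applies the periodic copy too early to show how the slave DOFs end up stale.

```python
import jax.numpy as jnp

# Illustrative index sets (assumed for this sketch).
n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
dirichlet_dofs = jnp.array([0, 7])
dirichlet_values = jnp.array([0.0, 0.0])
periodic_dofs = jnp.array([4, 6])          # slave DOFs
periodic_master_dofs = jnp.array([1, 3])   # their masters
u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])

# Correct order: free DOFs first, then Dirichlet values, then periodic copy.
u_ok = jnp.zeros(n_dofs)
u_ok = u_ok.at[free_dofs].set(u_reduced)
u_ok = u_ok.at[dirichlet_dofs].set(dirichlet_values)
u_ok = u_ok.at[periodic_dofs].set(u_ok[periodic_master_dofs])

# Wrong order: the periodic copy runs before the masters are filled in,
# so the slave DOFs 4 and 6 keep the stale value 0.
u_bad = jnp.zeros(n_dofs)
u_bad = u_bad.at[periodic_dofs].set(u_bad[periodic_master_dofs])
u_bad = u_bad.at[free_dofs].set(u_reduced)
u_bad = u_bad.at[dirichlet_dofs].set(dirichlet_values)

print("correct:", u_ok)
print("stale slaves:", u_bad)
```

Both variants run without error, which is what makes this kind of ordering mistake hard to spot when the same three lines are repeated in several places.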
Solution
Use Lifter to define constraints once and map consistently between:
- reduced vectors (free DOFs only)
- full vectors (all DOFs)
import jax
import jax.numpy as jnp
from tatva.lifter import DirichletBC, Lifter, PeriodicMap
n_dofs = 8
Define constraints
Here we fix two boundary DOFs and tie two periodic slave DOFs to masters.
dirichlet = DirichletBC(
    dofs=jnp.array([0, 7]),
    values=jnp.array([0.0, 0.0]),  # default is zeros, so this line is optional
)
periodic = PeriodicMap(
    dofs=jnp.array([4, 6]),
    master_dofs=jnp.array([1, 3]),
)
lifter = Lifter(n_dofs, dirichlet, periodic)
Lifter computes constrained/free sets and the reduced system size.
print("constrained dofs:", lifter.constrained_dofs)
print("free dofs:", lifter.free_dofs)
print("size:", lifter.size)
print("size_reduced:", lifter.size_reduced)
constrained dofs: [0 4 6 7]
free dofs: [1 2 3 5]
size: 8
size_reduced: 4
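The sets printed above can be reproduced by hand with simple set operations. The sketch below is not taken from tatva's implementation; it only assumes that the constrained set is the union of the Dirichlet DOFs and the periodic slave DOFs, and that every remaining DOF is free. Under that assumption the master DOFs (1 and 3) stay in the free set, which matches the output shown above.

```python
import jax.numpy as jnp

n_dofs = 8
dirichlet_dofs = jnp.array([0, 7])
periodic_slave_dofs = jnp.array([4, 6])

# Assumed rule: constrained = Dirichlet DOFs plus periodic slave DOFs.
constrained_dofs = jnp.union1d(dirichlet_dofs, periodic_slave_dofs)

# Everything that is not constrained is free (masters included).
free_dofs = jnp.setdiff1d(jnp.arange(n_dofs), constrained_dofs)
size_reduced = int(free_dofs.shape[0])

print("constrained dofs:", constrained_dofs)
print("free dofs:", free_dofs)
print("size_reduced:", size_reduced)
```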
Lifting (free -> full)
Use lift_from_zeros when you want the full vector reconstructed from a zero base state.
u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
u_full = lifter.lift_from_zeros(u_reduced)
print("u_reduced:", u_reduced)
print("u_full:", u_full)
u_reduced: [10. 20. 30. 40.]
u_full: [ 0. 10. 20. 30. 10. 40. 30. 0.]
Use lift when you want to start from a previous iterate and only overwrite free DOFs before applying constraints.
u_prev = jnp.linspace(-1.0, 1.0, n_dofs)
u_full_from_prev = lifter.lift(u_reduced, u_prev)
print("previous full state:", u_prev)
print("lifted full state:", u_full_from_prev)
previous full state: [-1. -0.71428573 -0.42857143 -0.14285709 0.1428572 0.42857146
0.71428585 1. ]
lifted full state: [ 0. 10. 20. 30. 10. 40. 30. 0.]
Usage in total energy function
from jax import Array

def total_energy(z_full: Array) -> Array:
    (u,) = Solution(z_full)  # See the guide for 'Compound'
    E = ...  # compute energy from u
    return E

def total_energy_free(z_free: Array) -> Array:
    # Reconstruct the full vector, then evaluate the energy on it.
    z_full = lifter.lift_from_zeros(z_free)
    return total_energy(z_full)

residual_fn = jax.jacrev(total_energy_free)
Or do it inline:
residual_fn = jax.jacrev(lambda z_free: total_energy(lifter.lift_from_zeros(z_free)))
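As a concrete illustration of this pattern, the sketch below minimizes a hypothetical quadratic chain energy with two Dirichlet DOFs. It does not use tatva: `lift_from_zeros` here is a hand-written stand-in for `lifter.lift_from_zeros`, and the energy, the index sets, and the single Newton step are assumptions chosen so the result is easy to check by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical 5-DOF chain: both ends are prescribed, the interior is free.
n_dofs = 5
free_dofs = jnp.array([1, 2, 3])
dirichlet_dofs = jnp.array([0, 4])
dirichlet_values = jnp.array([0.0, 1.0])

def lift_from_zeros(z_free):
    # Hand-written stand-in for lifter.lift_from_zeros (assumption).
    z_full = jnp.zeros(n_dofs).at[free_dofs].set(z_free)
    return z_full.at[dirichlet_dofs].set(dirichlet_values)

def total_energy(z_full):
    # Illustrative energy: unit springs between neighbouring DOFs.
    return 0.5 * jnp.sum(jnp.diff(z_full) ** 2)

def total_energy_free(z_free):
    return total_energy(lift_from_zeros(z_free))

residual_fn = jax.jacrev(total_energy_free)   # gradient w.r.t. free DOFs
stiffness_fn = jax.jacfwd(residual_fn)        # Hessian w.r.t. free DOFs

# One Newton step from zero in reduced space, then lift for inspection.
z_free = jnp.zeros(free_dofs.shape[0])
z_free = z_free - jnp.linalg.solve(stiffness_fn(z_free), residual_fn(z_free))
z_full = lift_from_zeros(z_free)
print("z_full:", z_full)
```

Because this energy is quadratic, a single Newton step reaches the minimizer, which is the linear profile between the two prescribed end values. The solver only ever sees the three free DOFs; the Dirichlet values enter through the lift.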
Reducing (full -> free)
reduce extracts the reduced vector from any full vector.
u_reduced_back = lifter.reduce(u_full)
print("reduced from full:", u_reduced_back)
reduced from full: [10. 20. 30. 40.]
Note
The round-trip reduce(lift_from_zeros(u_reduced)) returns u_reduced.
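The opposite composition is generally not the identity. The sketch below uses hand-written stand-ins (`manual_lift_from_zeros` and `manual_reduce` are hypothetical helpers, assuming that reducing simply gathers the free DOFs) to show that lifting a reduced vector returns the original full vector only if that vector already satisfied the constraints.

```python
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
slave_dofs = jnp.array([4, 6])
master_dofs = jnp.array([1, 3])

def manual_lift_from_zeros(u_r):
    # Free DOFs from u_r; Dirichlet DOFs 0 and 7 stay at their zero values.
    u_f = jnp.zeros(n_dofs).at[free_dofs].set(u_r)
    # Periodic slaves copy their masters.
    return u_f.at[slave_dofs].set(u_f[master_dofs])

def manual_reduce(u_f):
    # Assumed behaviour: keep only the free-DOF entries.
    return u_f[free_dofs]

u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
round_trip = manual_reduce(manual_lift_from_zeros(u_reduced))

# A full vector that violates the constraints is changed by reduce-then-lift.
u_any = jnp.arange(8.0)
projected = manual_lift_from_zeros(manual_reduce(u_any))

print("round trip:", round_trip)
print("projected:", projected)
```

Here `projected` differs from `u_any` at DOFs 0, 4, 6, and 7, because those entries are overwritten by the constraints.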
Build constraints incrementally
You can append constraints with add, which returns a new Lifter.
base_lifter = Lifter(n_dofs)
constrained_lifter = base_lifter.add(dirichlet).add(periodic)
print("base size_reduced:", base_lifter.size_reduced)
print("constrained size_reduced:", constrained_lifter.size_reduced)
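The key point is that `add` leaves the original object untouched. The toy class below illustrates that "return a new object" pattern in isolation. `MiniLifter` is a hypothetical stand-in written for this sketch, not tatva's `Lifter`: it takes plain tuples of DOF indices instead of constraint objects and only tracks the constrained set.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class MiniLifter:
    """Hypothetical stand-in used only to illustrate a non-mutating add."""
    size: int
    constrained: tuple = ()

    def add(self, dofs):
        # Build a new instance; self is frozen and never modified.
        merged = tuple(sorted(set(self.constrained) | set(dofs)))
        return replace(self, constrained=merged)

    @property
    def size_reduced(self):
        return self.size - len(self.constrained)

base = MiniLifter(8)
constrained_lifter = base.add((0, 7)).add((4, 6))

print("base size_reduced:", base.size_reduced)
print("constrained size_reduced:", constrained_lifter.size_reduced)
```

Since each call returns a fresh object, the unconstrained `base` can be reused to build several differently constrained variants without them interfering with one another.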
JAX compatibility
Lifter can be used inside JAX-transformed functions.
def reduced_energy(u_r: jax.Array) -> jax.Array:
    u_f = lifter.lift_from_zeros(u_r)
    return jnp.sum(u_f**2)

energy_jit = jax.jit(reduced_energy)
energy_grad = jax.grad(reduced_energy)
print("energy:", energy_jit(u_reduced))
print("gradient:", energy_grad(u_reduced))
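To see what differentiating through a lift implies, the sketch below repeats the same energy with a hand-written stand-in for the lift (`manual_lift` is a hypothetical helper, not tatva's code; the index sets are the ones from this page, with zero Dirichlet values). Under these assumptions each master DOF also receives the gradient contribution of its periodic slave.

```python
import jax
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
slave_dofs = jnp.array([4, 6])
master_dofs = jnp.array([1, 3])

def manual_lift(u_r):
    # Free DOFs from u_r; Dirichlet DOFs 0 and 7 stay at zero.
    u_f = jnp.zeros(n_dofs).at[free_dofs].set(u_r)
    # Slaves copy their masters, so gradients flow back to the masters.
    return u_f.at[slave_dofs].set(u_f[master_dofs])

def reduced_energy(u_r):
    return jnp.sum(manual_lift(u_r) ** 2)

u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
energy = jax.jit(reduced_energy)(u_reduced)
gradient = jax.grad(reduced_energy)(u_reduced)

print("energy:", energy)      # 4000 for this stand-in
print("gradient:", gradient)  # entries 40, 40, 120, 80 for this stand-in
```

The first and third gradient entries are doubled relative to `2 * u_reduced` because DOFs 1 and 3 each drive a slave (DOFs 4 and 6), while the second and fourth entries have no slave attached.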
Summary
Use Lifter when solving constrained systems in reduced coordinates.
- centralizes DOF bookkeeping for constraints
- provides explicit lift/reduce operations
- keeps reduced-space code clean and JAX-friendly