Applying Constraints or BCs
In constrained problems, we often solve only for free DOFs and reconstruct the full vector afterwards. Typical constraints include:
- Dirichlet conditions (fixed values)
- Periodic constraints (slave DOFs copied from master DOFs)
Since tatva works with plain jax.numpy arrays, you can manage these index sets manually. The catch is that the same lifting logic then has to be repeated throughout the codebase.
A manual implementation usually looks like this:
```python
u_full = u_full.at[free_dofs].set(u_reduced)  # write the unknowns
u_full = u_full.at[dirichlet_dofs].set(dirichlet_values)  # fixed values
u_full = u_full.at[periodic_dofs].set(u_full[periodic_master_dofs])  # slaves copy masters
```
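Here is a self-contained sketch of this manual approach in plain jax.numpy (no tatva), using the same 8-DOF example as the Lifter walkthrough below. It makes the hidden bookkeeping explicit. The free set has to be computed as the complement of every constrained DOF. The periodic copy has to run after the free DOFs are written, because it reads the master entries. The variable names are illustrative only.

```python
import jax.numpy as jnp

n_dofs = 8
dirichlet_dofs = jnp.array([0, 7])
dirichlet_values = jnp.array([0.0, 0.0])
periodic_dofs = jnp.array([4, 6])          # slave DOFs
periodic_master_dofs = jnp.array([1, 3])   # their masters

# Free DOFs = all DOFs minus every constrained DOF; this set must be
# recomputed by hand whenever a constraint is added or changed.
constrained = jnp.concatenate([dirichlet_dofs, periodic_dofs])
free_dofs = jnp.setdiff1d(jnp.arange(n_dofs), constrained)  # [1, 2, 3, 5]

u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
u_full = jnp.zeros(n_dofs)
u_full = u_full.at[free_dofs].set(u_reduced)
u_full = u_full.at[dirichlet_dofs].set(dirichlet_values)
# The periodic copy reads the masters, so it must come after the free DOFs are set
u_full = u_full.at[periodic_dofs].set(u_full[periodic_master_dofs])
# u_full is [0, 10, 20, 30, 10, 40, 30, 0]
```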
Problem
Managing index sets by hand is error-prone, especially when the constraint sets evolve. tatva therefore provides a utility, Lifter, that handles this bookkeeping for you.
Getting started with Lifter
Use Lifter to define constraints once and map consistently between:
- reduced vectors (free DOFs only)
- full vectors (all DOFs)
A Lifter takes the total number of DOFs followed by any number of constraints. The library currently includes two constraint types, Fixed and Periodic, and you are encouraged to add your own (see Custom constraints).
Colab Setup (Install Dependencies)
```python
# Only run this if we are in Google Colab
if "google.colab" in str(get_ipython()):
    print("Installing dependencies from pyproject.toml...")
    !apt-get install gmsh
    !apt-get install -qq xvfb libgl1-mesa-glx
    !pip install pyvista -qq
    # This installs the repo itself (and its dependencies)
    !pip install -q "git+https://github.com/smec-ethz/tatva-docs.git"
    print("Installation complete!")
```
```python
import jax
import jax.numpy as jnp
from tatva.lifter import Fixed, Lifter, Periodic

n_dofs = 8
dirichlet = Fixed(
    dofs=jnp.array([0, 7]),
    values=jnp.array([0.0, 0.0]),  # default is zeros, so this line is optional
)
periodic = Periodic(
    dofs=jnp.array([4, 6]),
    master_dofs=jnp.array([1, 3]),
)
lifter = Lifter(n_dofs, dirichlet, periodic)
```
Lifting (free -> full)
Use lift_from_zeros when you want the full vector reconstructed from a zero base state.
```python
u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
u_full = lifter.lift_from_zeros(u_reduced)
print("u_reduced:", u_reduced)
print("u_full:", u_full)
```
```
u_reduced: [10. 20. 30. 40.]
u_full: [ 0. 10. 20. 30. 10. 40. 30. 0.]
```
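Use lift when you want to start from a previous iterate. It overwrites only the free DOFs of that state and then applies the constraints. In this example every DOF is either free or set by a constraint, so the result is the same as with lift_from_zeros.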
```python
u_prev = jnp.linspace(-1.0, 1.0, n_dofs)
u_full_from_prev = lifter.lift(u_reduced, u_prev)
print("previous full state:", u_prev)
print("lifted full state:", u_full_from_prev)
```
```
previous full state: [-1. -0.71428573 -0.42857143 -0.14285709 0.1428572 0.42857146
 0.71428585 1. ]
lifted full state: [ 0. 10. 20. 30. 10. 40. 30. 0.]
```
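The behaviour described for lift can be sketched in plain JAX in two steps. First, write the free DOFs into a copy of a base state. Then let each constraint overwrite its own DOFs in turn. lift_from_zeros is the special case with a zero base. The classes and functions below (ToyFixed, ToyPeriodic, toy_lift, and so on) are a hypothetical toy modelled on the dofs/apply_lift interface described under Custom constraints. They are not tatva's implementation. In particular, this toy applies constraints in list order, which may not match how tatva orders them.

```python
from dataclasses import dataclass

import jax.numpy as jnp
from jax import Array


@dataclass(frozen=True)
class ToyFixed:
    dofs: Array
    values: Array

    def apply_lift(self, u_full: Array) -> Array:
        return u_full.at[self.dofs].set(self.values)


@dataclass(frozen=True)
class ToyPeriodic:
    dofs: Array         # slave DOFs
    master_dofs: Array

    def apply_lift(self, u_full: Array) -> Array:
        return u_full.at[self.dofs].set(u_full[self.master_dofs])


def toy_free_dofs(n_dofs, constraints):
    constrained = jnp.concatenate([c.dofs for c in constraints])
    return jnp.setdiff1d(jnp.arange(n_dofs), constrained)


def toy_lift(u_reduced, u_base, free_dofs, constraints):
    u_full = u_base.at[free_dofs].set(u_reduced)  # overwrite the free DOFs only
    for c in constraints:                         # then apply each constraint in order
        u_full = c.apply_lift(u_full)
    return u_full


def toy_lift_from_zeros(u_reduced, n_dofs, free_dofs, constraints):
    return toy_lift(u_reduced, jnp.zeros(n_dofs), free_dofs, constraints)


n_dofs = 8
constraints = [
    ToyFixed(jnp.array([0, 7]), jnp.array([0.0, 0.0])),
    ToyPeriodic(jnp.array([4, 6]), jnp.array([1, 3])),
]
free = toy_free_dofs(n_dofs, constraints)
u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
u_prev = jnp.linspace(-1.0, 1.0, n_dofs)

from_prev = toy_lift(u_reduced, u_prev, free, constraints)
from_zeros = toy_lift_from_zeros(u_reduced, n_dofs, free, constraints)
# Every DOF is free or overwritten by a constraint, so u_prev has no visible effect
```

The base state only shows through in DOFs that no step writes. This can happen with a constraint whose apply_lift reads some of its own DOFs from u_full instead of setting them.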
Usage in a total energy function
The simplest and most general approach is to treat the lifter as a static object. Capture it from the enclosing scope rather than passing it as a function argument, and use it to lift the reduced solution inside the energy function.
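```python
import jax
from jax import Array

# `lifter` (see above) and `Solution` (see the guide for 'Compound') are
# module-level objects, not function arguments.
def total_energy(z_full: Array) -> Array:
    (u,) = Solution(z_full)  # unpack the fields of the full vector
    E = ...  # compute energy from u
    return E

def total_energy_free(z_free: Array) -> Array:
    z_full = lifter.lift_from_zeros(z_free)  # free DOFs -> full vector
    return total_energy(z_full)

# Differentiating with respect to the free DOFs gives the reduced residual
residual_fn = jax.jacrev(total_energy_free)
```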
Or do the lifting inline:
```python
residual_fn = jax.jacrev(lambda z_free: total_energy(lifter.lift_from_zeros(z_free)))
```
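It helps to see what the reduced residual actually contains. For Fixed and Periodic constraints the lifting is affine, u_full = P u_reduced + u_D, so by the chain rule the reduced gradient is P^T times the full gradient. Each master DOF collects its own gradient entry plus those of its slaves, and Dirichlet DOFs drop out. The self-contained toy below checks this. Its assumptions: a chain-of-springs energy stands in for total_energy, and the lifting is written inline with jnp rather than with tatva's Lifter.

```python
import jax
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
dirichlet_dofs = jnp.array([0, 7])
slave_dofs, master_dofs = jnp.array([4, 6]), jnp.array([1, 3])

def lift_from_zeros(u_reduced):
    u = jnp.zeros(n_dofs).at[free_dofs].set(u_reduced)
    u = u.at[dirichlet_dofs].set(0.0)
    return u.at[slave_dofs].set(u[master_dofs])

# Toy energy: unit springs between neighbouring DOFs plus a linear load f
f = jnp.arange(n_dofs, dtype=jnp.float32)
def total_energy(u_full):
    return 0.5 * jnp.sum(jnp.diff(u_full) ** 2) - jnp.dot(f, u_full)

residual_fn = jax.jacrev(lambda u_r: total_energy(lift_from_zeros(u_r)))
u_r = jnp.array([1.0, 2.0, 3.0, 4.0])
r_reduced = residual_fn(u_r)

# Manual P^T g: take the free entries, then add each slave's entry to its master.
# Masters 1 and 3 sit at reduced indices 0 and 2.
g_full = jax.grad(total_energy)(lift_from_zeros(u_r))
r_manual = g_full[free_dofs].at[jnp.array([0, 2])].add(g_full[slave_dofs])
# Both are [-10, -2, -4, -1]
```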
Reducing (full -> free)
reduce extracts the reduced vector from any full vector.
```python
u_reduced_back = lifter.reduce(u_full)
print("reduced from full:", u_reduced_back)
```
```
reduced from full: [10. 20. 30. 40.]
```
Note
The round-trip reduce(lift_from_zeros(u_reduced)) returns u_reduced.
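Here is a plain-JAX sketch of reduce as a gather of the free entries. This is an assumption about what it computes, though it is consistent with the output above. The sketch also shows that the opposite round trip, lift after reduce, only recovers a full vector that already satisfies the constraints. The helper names are hypothetical.

```python
import jax.numpy as jnp

n_dofs = 8
free_dofs = jnp.array([1, 2, 3, 5])
slave_dofs, master_dofs = jnp.array([4, 6]), jnp.array([1, 3])

def lift_from_zeros(u_reduced):
    # DOFs 0 and 7 stay at zero (homogeneous Dirichlet); slaves copy masters
    u = jnp.zeros(n_dofs).at[free_dofs].set(u_reduced)
    return u.at[slave_dofs].set(u[master_dofs])

def reduce_free(u_full):
    return u_full[free_dofs]  # gather the free entries

u_reduced = jnp.array([10.0, 20.0, 30.0, 40.0])
back = reduce_free(lift_from_zeros(u_reduced))  # equals u_reduced

u_arbitrary = jnp.arange(n_dofs, dtype=jnp.float32)  # violates both constraint types
relifted = lift_from_zeros(reduce_free(u_arbitrary))
# relifted is [0, 1, 2, 3, 1, 5, 3, 0]: the constrained entries were overwritten
```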
Constraints
Some constraints are static (like the ones above).
Others are dynamic, e.g., a Dirichlet BC used to apply displacement-controlled loading.
For these, tatva provides RuntimeValue.
Fixed supports RuntimeValue out of the box, and you can add RuntimeValue support to your own constraints (see Custom constraints).
Each RuntimeValue has a key (and optionally a default value). Once the lifter is built, these values can be changed with:
```python
lifter = lifter.at["top"].set(...)  # change one value
lifter = lifter.with_values({"top": ..., "other": ...})  # set several at once
```
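Both calls reassign lifter. This suggests a functional update that returns a new object instead of mutating the old one, in the same style as JAX's `.at[...].set(...)` on arrays. The toy below illustrates that pattern with a frozen dataclass. ToyLifter, runtime_values and with_values here are hypothetical and are not tatva's implementation.

```python
from dataclasses import dataclass, field, replace
from types import MappingProxyType


@dataclass(frozen=True)
class ToyLifter:
    # Read-only mapping from RuntimeValue key to its current value
    runtime_values: MappingProxyType = field(
        default_factory=lambda: MappingProxyType({})
    )

    def with_values(self, updates: dict) -> "ToyLifter":
        # Build a new mapping and a new object; `self` is left untouched
        merged = {**self.runtime_values, **updates}
        return replace(self, runtime_values=MappingProxyType(merged))


base = ToyLifter(MappingProxyType({"top": 0.0, "other": 1.0}))
stepped = base.with_values({"top": 0.1})
# base still holds top=0.0; stepped holds top=0.1 and keeps other=1.0
```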
Fixed
Fixed pins DOFs to prescribed values, often zero. The signature is Fixed(dofs, values), where values defaults to zeros.
Pass a RuntimeValue as values to make the constraint dynamic:
```python
from tatva.lifter import RuntimeValue

dirichlet = Fixed(jnp.array([0, 7]))  # values default to zeros
dirichlet_load = Fixed(jnp.array([5]), RuntimeValue("top"))  # DOF 5 follows the "top" value
lifter = Lifter(n_dofs, dirichlet, dirichlet_load)
```
Then update the value with:
```python
lifter = lifter.at["top"].set(0.1)  # set the value of the "top" load to 0.1
lifter.lift_from_zeros(
    jnp.zeros(lifter.size_reduced)
)  # DOF 5 is set to 0.1 by the Dirichlet load
```
```
Array([0. , 0. , 0. , 0. , 0. , 0.1, 0. , 0. ], dtype=float32)
```
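A typical use of a runtime Dirichlet value is displacement-controlled load stepping. The prescribed value changes at every step while the set of constrained DOFs stays fixed. Below is a tatva-free sketch of that loop. The prescribed value for DOF 5 is passed as an argument of a jitted function, so each step reuses the same compiled code. lift_with_load is a hypothetical helper, and the solve step is left as a placeholder comment.

```python
import jax
import jax.numpy as jnp

n_dofs = 8
fixed_dofs = jnp.array([0, 7, 5])
free_dofs = jnp.setdiff1d(jnp.arange(n_dofs), fixed_dofs)  # [1, 2, 3, 4, 6]


@jax.jit
def lift_with_load(u_reduced, u_top):
    u = jnp.zeros(n_dofs).at[free_dofs].set(u_reduced)
    u = u.at[jnp.array([0, 7])].set(0.0)  # homogeneous Dirichlet DOFs
    return u.at[5].set(u_top)             # displacement-controlled DOF


u_reduced = jnp.zeros(free_dofs.shape[0])
history = []
for u_top in jnp.linspace(0.0, 0.3, 4):  # load steps 0.0, 0.1, 0.2, 0.3
    u_full = lift_with_load(u_reduced, u_top)
    # ... solve for u_reduced at this load level here ...
    history.append(u_full[5])
```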
Periodic
Periodic defines a periodic map between DOFs: each slave DOF takes the value of its master DOF. The signature is Periodic(dofs, master_dofs), where dofs are the slave DOFs.
```python
periodic = Periodic(jnp.array([4, 6]), jnp.array([1, 3]))
lifter = Lifter(n_dofs, periodic)
u_free = jnp.zeros(lifter.size_reduced)
# The free DOFs are [0, 1, 2, 3, 5, 7], so reduced indices 1 and 3 are the master DOFs 1 and 3
u_free = u_free.at[1].set(1.0).at[3].set(1.0)
lifter.lift_from_zeros(u_free)  # the slave DOFs 4 and 6 are also set to 1.0
```
```
Array([0., 1., 0., 1., 1., 0., 1., 0.], dtype=float32)
```
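For Fixed and Periodic constraints, lifting is affine: u_full = P u_free + u_D. Here P copies each free DOF into its own slot and into the slots of its slaves, and u_D holds the prescribed values (zero in this example). The sketch below builds P explicitly for the periodic example above and reproduces the printed result. The dense matrix is for illustration only, and nothing here claims that tatva assembles it. It also assumes every master is itself a free DOF.

```python
import jax.numpy as jnp

n_dofs = 8
slave_dofs, master_dofs = [4, 6], [1, 3]
free_dofs = [d for d in range(n_dofs) if d not in slave_dofs]  # [0, 1, 2, 3, 5, 7]

# Column j of P corresponds to free DOF free_dofs[j]
P = jnp.zeros((n_dofs, len(free_dofs)))
for j, d in enumerate(free_dofs):
    P = P.at[d, j].set(1.0)
for s, m in zip(slave_dofs, master_dofs):
    P = P.at[s].set(P[m])  # a slave row is a copy of its master's row

u_free = jnp.zeros(len(free_dofs)).at[1].set(1.0).at[3].set(1.0)
u_full = P @ u_free  # [0, 1, 0, 1, 1, 0, 1, 0]
```

The transpose P.T maps a full gradient to the reduced one. Differentiating through the lift computes this product implicitly.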
Custom constraints
To create a custom constraint, inherit from tatva.lifter.Constraint.
A custom constraint must:
- set the attribute self.dofs to the constrained DOFs, and
- implement apply_lift(self, u_full: Array) -> Array, which sets the constrained DOFs in the full solution array and returns the updated array.
Example: Rigid body MPC
This constraint sets the DOFs of the nodes in disk_nodes from the disk-center DOFs (translation and rotation), so that the disk moves as a rigid body.
```python
import jax.numpy as jnp
from jax import Array
from tatva.compound import Compound, field
from tatva.lifter import Constraint, RuntimeValue
from tatva.mesh import Mesh

mesh: Mesh  # assumed to be created elsewhere (provides mesh.coords)
n_nodes = 10  # Example number of nodes in the mesh


class Solution(Compound):
    u = field(shape=(n_nodes, 2))  # Displacement field for all nodes
    u_disk = field(shape=(1, 3))  # (ux_c, uy_c, theta_c) for the disk


class RigidBody(Constraint):
    def __init__(
        self, disk_center: Array, disk_nodes: Array, ux_disk: float | RuntimeValue
    ):
        # ux_disk may be a RuntimeValue so that the disk center's
        # x-displacement can be set at runtime
        self.ux_disk = ux_disk
        self.disk_nodes = disk_nodes
        self.disk_center = disk_center
        self.dofs = jnp.concatenate(
            [
                Solution.u[disk_nodes, :].flatten(),
                # Dirichlet BC for the center x- and y-displacements:
                Solution.u_disk[0, :2].flatten(),
            ]
        )

    def apply_lift(self, u_full: Array) -> Array:
        """Apply the rigid body constraint by setting the displacements at the disk nodes
        based on the center displacements and rotation.
        """
        sol = Solution(u_full)
        (u, u_disk) = sol
        _, uy_c, theta_c = u_disk[0]
        # use _resolve_runtime to get the value of ux_c at runtime:
        ux_c = self._resolve_runtime(self.ux_disk)
        # Set the center x-displacement
        sol.u_disk = u_disk.at[0, 0].set(ux_c)
        # Rigid body displacements of the disk nodes from the center
        # displacements and rotation
        vecs = mesh.coords[self.disk_nodes] - self.disk_center
        cos, sin = jnp.cos(theta_c), jnp.sin(theta_c)
        rotation_matrix = jnp.array([[cos, -sin], [sin, cos]])
        rotated_vecs = vecs @ rotation_matrix.T
        u_rotation = rotated_vecs - vecs
        u_rigid = jnp.stack([ux_c, uy_c]) + u_rotation
        # Set the displacements at the disk nodes
        sol.u = u.at[self.disk_nodes, :].set(u_rigid)
        return sol.arr
```
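apply_lift uses the finite-rotation rigid-body map u_i = t + (R(θ) − I)(x_i − x_c), where t is the center translation. A node therefore ends up at x_c + t + R(θ)(x_i − x_c). The standalone check below uses plain JAX, made-up coordinates and made-up center DOFs. It confirms that this map preserves the distances between disk nodes, i.e., that it really is rigid.

```python
import jax.numpy as jnp

coords = jnp.array([[0.6, 0.5], [0.5, 0.6], [0.4, 0.5]])  # made-up disk nodes
center = jnp.array([0.5, 0.5])
ux_c, uy_c, theta_c = 0.1, -0.05, 0.3                      # made-up center DOFs

vecs = coords - center
cos, sin = jnp.cos(theta_c), jnp.sin(theta_c)
R = jnp.array([[cos, -sin], [sin, cos]])
u_rigid = jnp.array([ux_c, uy_c]) + (vecs @ R.T - vecs)  # same formula as apply_lift


def pairwise_distances(x):
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))


d_before = pairwise_distances(coords)
d_after = pairwise_distances(coords + u_rigid)  # equal to d_before up to round-off
```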
Because self.ux_disk may be a RuntimeValue, the disk's x-displacement can be changed directly through the lifter:
```python
lifter = Lifter(
    Solution.size,
    RigidBody(
        disk_center=jnp.array([0.5, 0.5]),
        disk_nodes=jnp.array([2, 3, 4]),
        ux_disk=RuntimeValue("ux_disk", default=0.0),
    ),
)
lifter = lifter.at["ux_disk"].set(0.1)  # Set the center x-displacement to 0.1
```
JAX compatibility
Lifter can be used inside JAX-jitted functions.
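You have three options:
- Static: the lifter object is not an argument to the function; it is captured from the enclosing scope.
```python
import jax
import jax.numpy as jnp

@jax.jit
def total_energy(u_free: jax.Array, u_top: float) -> jax.Array:
    # Bind the updated lifter to a new name: assigning to `lifter` inside the
    # function would make it a local variable and raise UnboundLocalError.
    lifter_top = lifter.at["top"].set(u_top)
    u_f = lifter_top.lift_from_zeros(u_free)
    return jnp.sum(u_f**2)
```
- Static argument: the lifter is passed as an argument marked static.
```python
from functools import partial

@partial(jax.jit, static_argnames=["lifter"])
def total_energy(u_free: jax.Array, u_top: float, lifter: Lifter) -> jax.Array:
    lifter = lifter.at["top"].set(u_top)
    u_f = lifter.lift_from_zeros(u_free)
    return jnp.sum(u_f**2)
```
- Traced (dynamic) argument: the lifter is passed as a regular argument, and the function is called with an updated lifter.
```python
@jax.jit
def total_energy(u_free: jax.Array, lifter: Lifter) -> jax.Array:
    u_f = lifter.lift_from_zeros(u_free)
    return jnp.sum(u_f**2)

total_energy(u_free, lifter.at["top"].set(u_top))
```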
Note
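If you pass a lifter as a dynamic argument to a jitted function, only its runtime values are traced. All other parts of the lifter are treated as static, because they are stored as the aux_data of a custom pytree. Updating a runtime value therefore reuses the compiled function, while changing the constraints themselves triggers a retrace.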
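The mechanism in this note can be reproduced with a toy pytree. ToyLifter below is hypothetical, not tatva's class. Its runtime value is the traced child, and the index of the loaded DOF lives in aux_data. A Python counter inside the jitted function only runs while JAX traces it. The counter therefore shows that a new runtime value reuses the compiled function, whereas different static data forces a retrace.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyLifter:
    """Hypothetical stand-in: one runtime value, one static load DOF."""

    def __init__(self, load_dof: int, top_value):
        self.load_dof = load_dof    # static: stored in aux_data
        self.top_value = top_value  # dynamic: the only pytree child (traced)

    def tree_flatten(self):
        return (self.top_value,), self.load_dof

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)


n_traces = 0


@jax.jit
def lift_from_zeros(u_reduced, lifter):
    global n_traces
    n_traces += 1  # Python side effect: only runs while JAX traces
    free = jnp.array([d for d in range(8) if d != lifter.load_dof])
    u = jnp.zeros(8).at[free].set(u_reduced)
    return u.at[lifter.load_dof].set(lifter.top_value)


u_reduced = jnp.zeros(7)
lift_from_zeros(u_reduced, ToyLifter(5, 0.1))
out_value = lift_from_zeros(u_reduced, ToyLifter(5, 0.2))   # new runtime value
traces_after_value_change = n_traces                         # still 1: cache hit
out_static = lift_from_zeros(u_reduced, ToyLifter(6, 0.2))  # new aux_data
# n_traces is now 2: the changed static data forced a retrace
```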
Summary
Use Lifter when solving constrained systems in reduced coordinates.
- centralizes DOF bookkeeping for constraints
- provides explicit lift/reduce operations
- keeps reduced-space code clean and JAX-friendly