Heat Transfer in a Plate with a Hole¤
This notebook demonstrates how to solve a transient heat transfer problem using tatva. We consider a plate with a circular hole, where temperature boundary conditions are applied to the top and bottom edges. The problem is solved using the Finite Element Method (FEM) with a matrix-free approach and an implicit time-stepping scheme.
Governing Equation¤
The transient heat conduction is governed by the heat equation:
\[\rho c_p \frac{\partial T}{\partial t} = \nabla \cdot \left( k \nabla T \right) + q\]
where: - \(T\) is the temperature field, - \(\rho\) is the density, - \(c_p\) is the specific heat capacity, - \(k\) is the thermal conductivity, - \(q\) is the internal heat source (assumed zero here).
Energy Formulation¤
In tatva, we define the problem using an energy-centric approach. We define energy functionals for the stiffness and mass (storage) parts.
The Thermal Energy (stiffness) representing steady-state conduction is:
\[\Psi_{th}(T) = \int_{\Omega} \frac{1}{2} k \, \nabla T \cdot \nabla T \, \mathrm{d}\Omega\]
The Mass Energy (storage) representing heat capacity is:
\[\Psi_{m}(T) = \int_{\Omega} \frac{1}{2} \rho c_p T^2 \, \mathrm{d}\Omega\]
The internal "forces" (thermal and inertial) are obtained by taking the derivative of these energies with respect to the temperature field \(T\). Specifically, the stiffness force is \(\mathbf{f}_{th} = \frac{\partial \Psi_{th}}{\partial \mathbf{T}}\) and the mass force is \(\mathbf{f}_{m} = \frac{\partial \Psi_{m}}{\partial \mathbf{T}}\).
Time Discretization¤
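We use the Implicit Euler scheme for time discretization. For a time step \(\Delta t\), the equilibrium at step \(n+1\) is reached when the residual is zero:
\[\mathbf{R}(\mathbf{T}_{n+1}) = \mathbf{M} \frac{\mathbf{T}_{n+1} - \mathbf{T}_n}{\Delta t} + \mathbf{K} \mathbf{T}_{n+1} - \mathbf{F}_{ext} = \mathbf{0}\]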
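where \(\mathbf{M}\) and \(\mathbf{K}\) are the mass and stiffness matrices (Hessians of the energy functionals). Rearranging for the increment \(\Delta \mathbf{T} = \mathbf{T}_{n+1} - \mathbf{T}_n\):
\[\left( \mathbf{M} + \Delta t \, \mathbf{K} \right) \Delta \mathbf{T} = \Delta t \left( \mathbf{F}_{ext} - \mathbf{K} \mathbf{T}_n \right)\]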
Instead of assembling these matrices explicitly, we use JAX's Automatic Differentiation (AD) to compute Jacobian-vector products, enabling a matrix-free solver using the Conjugate Gradient (CG) method.
Colab Setup (Install Dependencies)
# Only run this if we are in Google Colab
# NOTE(review): `get_ipython` is not imported anywhere in this file; presumably
# it is the name IPython injects into notebook sessions, so this cell is meant
# to be run from a notebook rather than as a plain Python script.
if 'google.colab' in str(get_ipython()):
    print("Installing dependencies using uv...")
    # Lines starting with `!` are shell commands executed by the notebook,
    # not Python statements. `uv` is installed first because the later
    # commands invoke it.
    !pip install -q uv
    !apt-get install -qq gmsh
    !uv pip install --system matplotlib meshio equinox
    !uv pip install --system "git+https://github.com/smec-ethz/tatva-docs.git"
    print("Installation complete!")
import os
import time
from typing import Callable
import equinox as eqx
import gmsh
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import meshio
import numpy as np
from jax import Array
from jax_autovmap import autovmap
from matplotlib.axes import Axes
from tatva import Mesh, Operator, element
# Turn on JAX's `jax_enable_x64` flag (64-bit types). It is set here, directly
# after the imports, so it is in effect before any array is created below.
jax.config.update("jax_enable_x64", True)
Code for generating a refined plate with a hole and plotting the mesh
def plot_mesh(mesh: Mesh, ax: Axes | None = None) -> None:
    """Draw the triangles of ``mesh`` with a single face colour and thin edges.

    Args:
        mesh: Mesh whose ``coords`` (columns 0 and 1 are used as x and y) and
            ``elements`` (triangle connectivity) are plotted.
        ax: Axes to draw on. When omitted, a new figure and axes are created.
    """
    if ax is None:
        _, ax = plt.subplots()

    x_nodes = mesh.coords[:, 0]
    y_nodes = mesh.coords[:, 1]
    # The same value for every triangle gives all cells one face colour, so
    # only the black element edges are visible as structure.
    uniform_faces = np.ones(len(mesh.elements))

    ax.tripcolor(
        x_nodes,
        y_nodes,
        mesh.elements,
        facecolors=uniform_faces,
        cmap="managua",
        edgecolors="k",
        linewidth=0.2,
    )
    ax.set_aspect("equal")
def generate_refined_plate_with_hole(
    width: float,
    height: float,
    hole_radius: float,
    mesh_size_fine: float,
    mesh_size_coarse: float,
) -> Mesh:
    """Mesh a rectangular plate with a centred circular hole, refined near the hole.

    The geometry is built with the gmsh OpenCASCADE kernel, meshed with
    triangles, written to ``<cwd>/meshes/plate_hole_refined.msh`` (the
    directory is created if needed) and read back with meshio.

    Args:
        width: Plate extent in x; the plate spans ``[0, width]``.
        height: Plate extent in y; the plate spans ``[0, height]``.
        hole_radius: Radius of the hole centred at ``(width / 2, height / 2)``.
        mesh_size_fine: Element size requested on the hole boundary.
        mesh_size_coarse: Element size requested at a distance of
            ``2 * hole_radius`` or more from the hole boundary.

    Returns:
        A ``Mesh`` holding the 2D node coordinates and the triangle
        connectivity.
    """
    mesh_dir = os.path.join(os.getcwd(), "meshes")
    os.makedirs(mesh_dir, exist_ok=True)
    output_filename = os.path.join(mesh_dir, "plate_hole_refined.msh")

    gmsh.initialize()
    # try/finally guarantees that the gmsh session is closed even when a
    # geometry or meshing call raises; otherwise a failed run would leave gmsh
    # initialized for the rest of the Python session.
    try:
        gmsh.model.add("plate_with_hole_refined")
        occ = gmsh.model.occ

        # Geometry: rectangle minus a disk at its centre.
        rect = occ.addRectangle(0, 0, 0, width, height)
        cx = width / 2.0
        cy = height / 2.0
        disk = occ.addDisk(cx, cy, 0, hole_radius, hole_radius)
        out, _ = occ.cut([(2, rect)], [(2, disk)])
        occ.synchronize()

        surface_tag = out[0][1]
        gmsh.model.addPhysicalGroup(2, [surface_tag], 1, name="domain")

        boundaries = gmsh.model.getBoundary(out, oriented=False)
        boundary_tags = [b[1] for b in boundaries]
        gmsh.model.addPhysicalGroup(1, boundary_tags, 2, name="boundaries")

        # A boundary curve whose bounding box lies strictly inside the plate
        # cannot be one of the outer edges, so it belongs to the hole.
        hole_curve_tags = []
        for tag in boundary_tags:
            xmin, ymin, zmin, xmax, ymax, zmax = gmsh.model.getBoundingBox(1, tag)
            if xmin > 0 and xmax < width and ymin > 0 and ymax < height:
                hole_curve_tags.append(tag)

        # Field 1: distance to the hole curves.
        gmsh.model.mesh.field.add("Distance", 1)
        gmsh.model.mesh.field.setNumbers(1, "CurvesList", hole_curve_tags)
        gmsh.model.mesh.field.setNumber(1, "NumPointsPerCurve", 100)

        # Field 2: element size as a function of that distance, going from
        # the fine size at the hole to the coarse size at 2 * hole_radius.
        gmsh.model.mesh.field.add("Threshold", 2)
        gmsh.model.mesh.field.setNumber(2, "InField", 1)
        gmsh.model.mesh.field.setNumber(2, "SizeMin", mesh_size_fine)
        gmsh.model.mesh.field.setNumber(2, "SizeMax", mesh_size_coarse)
        gmsh.model.mesh.field.setNumber(2, "DistMin", 0.0)
        gmsh.model.mesh.field.setNumber(2, "DistMax", hole_radius * 2.0)
        gmsh.model.mesh.field.setAsBackgroundMesh(2)

        # Switch off the other mesh-size sources so that only the background
        # field determines the element size.
        gmsh.option.setNumber("Mesh.MeshSizeExtendFromBoundary", 0)
        gmsh.option.setNumber("Mesh.MeshSizeFromPoints", 0)
        gmsh.option.setNumber("Mesh.MeshSizeFromCurvature", 0)

        gmsh.model.mesh.generate(2)
        gmsh.write(output_filename)
    finally:
        gmsh.finalize()

    _mesh = meshio.read(output_filename)
    # Keep only x and y: the problem is planar.
    coords = _mesh.points[:, :2]
    elements = _mesh.cells_dict["triangle"]
    return Mesh(coords=coords, elements=elements)
# Square plate; the hole radius is a quarter of the side length.
lx = 100.0
ly = lx
radius = lx / 4.0

mesh = generate_refined_plate_with_hole(
    width=lx,
    height=ly,
    hole_radius=radius,
    mesh_size_fine=1.0,
    mesh_size_coarse=2.0,
)

# One entry per mesh node; used below as the length of the temperature vector.
n_dofs = mesh.coords.shape[0]
plot_mesh(mesh)
Info : Meshing 1D...
Info : [ 0%] Meshing curve 5 (Ellipse)
Info : [ 30%] Meshing curve 6 (Line)
Info : [ 50%] Meshing curve 7 (Line)
Info : [ 70%] Meshing curve 8 (Line)
Info : [ 90%] Meshing curve 9 (Line)
Info : Done meshing 1D (Wall 0.0323048s, CPU 0.032675s)
Info : Meshing 2D...
Info : Meshing surface 1 (Plane, Frontal-Delaunay)
Info : Done meshing 2D (Wall 0.168509s, CPU 0.166183s)
Info : 5689 nodes 11383 elements
Info : Writing '/home/mohit/Documents/dev/tatva-docs/notebooks/examples/meshes/plate_hole_refined.msh'...
Info : Done writing '/home/mohit/Documents/dev/tatva-docs/notebooks/examples/meshes/plate_hole_refined.msh'

Physical Parameters and BCs¤
We define the material properties and the Dirichlet boundary conditions for the top and bottom edges.
# Material properties.
k = 45e-3  # W/mm-K, thermal conductivity
rho = 7.8e-6  # kg/mm^3, density
cp = 496  # J/kg-K, specific heat capacity

# Temperatures prescribed on the top and bottom edges.
T_top = 1.0
T_bottom = 0.1

# Operator built on the mesh with Tri3 elements.
op = Operator(mesh, element.Tri3())

# Node sets on the edges y = ly (top) and y = 0 (bottom), selected with
# jnp.isclose on the y coordinate.
y_coords = mesh.coords[:, 1]
on_top_edge = jnp.isclose(y_coords, ly, atol=1e-8)
on_bottom_edge = jnp.isclose(y_coords, 0.0, atol=1e-8)
top_nodes = jnp.where(on_top_edge)[0]
bottom_nodes = jnp.where(on_bottom_edge)[0]

# All nodes carrying a prescribed temperature: top first, then bottom.
fixed_dofs = jnp.concatenate([top_nodes, bottom_nodes])
Variational Forms and Residuals¤
Instead of assembling matrices, we define energy densities corresponding to the stiffness and mass contributions. Their derivatives with respect to temperature will provide the internal and inertial forces.
@autovmap(grad_T=1)
def thermal_energy_density(grad_T: Array) -> Array:
    """Return the conduction energy density ``0.5 * k * (grad_T . grad_T)``.

    Written for a single temperature-gradient vector; the ``autovmap``
    decorator (declared with ``grad_T=1``) presumably maps it over any
    leading batch axes -- confirm against the jax_autovmap documentation.
    """
    squared_norm = jnp.dot(grad_T, grad_T)
    return 0.5 * k * squared_norm
@jax.jit
def total_thermal_energy(T_flat: Array) -> Array:
    """Integrate the conduction energy density of the temperature field.

    Args:
        T_flat: Nodal temperature vector.

    Returns:
        The value of ``op.integrate`` applied to the thermal energy density
        evaluated on ``op.grad(T_flat)``.
    """
    density = thermal_energy_density(op.grad(T_flat))
    return op.integrate(density)
@autovmap(T=0)
def mass_energy_density(T: Array) -> Array:
    """Return the storage energy density ``0.5 * rho * cp * T**2``.

    Written for a single scalar temperature; the ``autovmap`` decorator
    (declared with ``T=0``) presumably maps it over batch axes -- confirm
    against the jax_autovmap documentation.
    """
    half_rho_cp = 0.5 * rho * cp
    return half_rho_cp * T * T
@jax.jit
def total_mass_energy(T_flat: Array) -> Array:
    """Integrate the storage (mass) energy density of the temperature field.

    Args:
        T_flat: Nodal temperature vector.

    Returns:
        The value of ``op.integrate`` applied to the mass energy density
        evaluated on ``op.eval(T_flat)``.
    """
    density = mass_energy_density(op.eval(T_flat))
    return op.integrate(density)
# Derivative of the total thermal energy with respect to the nodal temperature
# vector (the "stiffness force"), obtained with reverse-mode differentiation.
compute_stiffness_force = jax.jacrev(total_thermal_energy)
# Derivative of the total mass (storage) energy with respect to the nodal
# temperature vector (the "mass force").
compute_mass_force = jax.jacrev(total_mass_energy)
Matrix-Free Solver¤
We use the Conjugate Gradient (CG) method to solve the linear system \((M + \Delta t K) \Delta T = \Delta t (F_{ext} - K T_n)\) at each time step. The action of the matrix is computed using JAX's jvp (Jacobian-Vector Product).
@jax.jit
def K_times_x(T_prev: Array, dT: Array) -> Array:
    """Apply the tangent of the stiffness force at ``T_prev`` to ``dT``.

    The product is formed with ``jax.jvp``, so no matrix is assembled.
    Entries at ``fixed_dofs`` are zeroed in the input direction and again in
    the result, which keeps the prescribed nodes out of the linear system.

    Args:
        T_prev: Temperature vector at which the tangent is taken.
        dT: Direction vector to multiply.

    Returns:
        The tangent-vector product with zeros at ``fixed_dofs``.
    """
    free_direction = dT.at[fixed_dofs].set(0.0)
    _, product = jax.jvp(compute_stiffness_force, (T_prev,), (free_direction,))
    return product.at[fixed_dofs].set(0.0)
@jax.jit
def M_times_x(T_prev: Array, dT: Array) -> Array:
    """Apply the tangent of the mass force at ``T_prev`` to ``dT``.

    The product is formed with ``jax.jvp``, so no matrix is assembled.
    Entries at ``fixed_dofs`` are zeroed in the input direction and again in
    the result, which keeps the prescribed nodes out of the linear system.

    Args:
        T_prev: Temperature vector at which the tangent is taken.
        dT: Direction vector to multiply.

    Returns:
        The tangent-vector product with zeros at ``fixed_dofs``.
    """
    free_direction = dT.at[fixed_dofs].set(0.0)
    _, product = jax.jvp(compute_mass_force, (T_prev,), (free_direction,))
    return product.at[fixed_dofs].set(0.0)
@jax.jit
def A_times_x(dT: Array, T_prev: Array, dt: float) -> Array:
    """Apply the implicit Euler system operator ``M + dt * K`` to ``dT``.

    Args:
        dT: Vector to multiply (the temperature increment in the solver).
        T_prev: Temperature vector at which both tangents are evaluated.
        dt: Time step size.

    Returns:
        ``M_times_x(T_prev, dT) + dt * K_times_x(T_prev, dT)``.
    """
    mass_part = M_times_x(T_prev, dT)
    stiffness_part = K_times_x(T_prev, dT)
    return mass_part + dt * stiffness_part
@eqx.filter_jit
def conjugate_gradient(A: Callable, b: Array, atol=1e-8, max_iter=500):
    """Solve ``A(x) = b`` with the conjugate gradient method, starting from zero.

    Args:
        A: Callable returning the operator applied to a vector.
        b: Right-hand side vector.
        atol: Iteration stops once the 2-norm of the residual is not greater
            than this value.
        max_iter: Upper bound on the number of iterations.

    Returns:
        A tuple ``(x, n_iterations)`` with the final iterate and the number of
        iterations performed.
    """
    x_start = jnp.zeros_like(b)
    residual_start = b - A(x_start)
    # State layout: (iterate, residual, search direction, residual . residual,
    # iteration counter). The first search direction is the initial residual.
    start_state = (
        x_start,
        residual_start,
        residual_start,
        jnp.vdot(residual_start, residual_start),
        0,
    )

    def not_converged(state):
        _, _, _, res_sq, n_done = state
        return jnp.logical_and(jnp.sqrt(res_sq) > atol, n_done < max_iter)

    def cg_step(state):
        x, residual, direction, res_sq, n_done = state
        A_direction = A(direction)
        step_length = res_sq / jnp.vdot(direction, A_direction)
        x_next = x + step_length * direction
        residual_next = residual - step_length * A_direction
        res_sq_next = jnp.vdot(residual_next, residual_next)
        direction_next = residual_next + (res_sq_next / res_sq) * direction
        return x_next, residual_next, direction_next, res_sq_next, n_done + 1

    x, _, _, _, n_done = jax.lax.while_loop(not_converged, cg_step, start_state)
    return x, n_done
Transient Simulation Loop¤
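The initial condition is \(T=0\) everywhere in the plate except on the top and bottom edges, where the prescribed boundary temperatures are imposed from the start. We run the simulation for a total of 200 s with a time step of \(\Delta t = 1\) s.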
# Initial condition: T=0 everywhere except at boundaries
T_init = jnp.zeros(n_dofs)
T_init = T_init.at[top_nodes].set(T_top)
T_init = T_init.at[bottom_nodes].set(T_bottom)

total_time = 200.0
dt = 1.0
# round() instead of int(): a quotient such as 0.3 / 0.1 evaluates to
# 2.9999999999999996, which int() would truncate to 2 and silently drop the
# last step. For the values used here both give 200.
n_steps = round(total_time / dt)
save_interval = 10

# No external heat input.
f_ext = jnp.zeros(n_dofs)
T_curr = T_init
# The first entry is the initial state; later entries are appended in the loop.
T_history = [T_curr]

print(f"Running {n_steps} steps...")
t_start = time.perf_counter()
for step in range(n_steps):
    # Right-hand side dt * (f_ext - f_int(T_n)), zeroed at the prescribed
    # nodes so that their temperature increment stays zero.
    fint = compute_stiffness_force(T_curr)
    rhs = (f_ext - fint) * dt
    rhs = rhs.at[fixed_dofs].set(0.0)

    # Bind the current state and time step so the solver sees a one-argument
    # operator dT -> (M + dt * K) dT.
    Amat = eqx.Partial(A_times_x, T_prev=T_curr, dt=dt)
    dT, cg_iters = conjugate_gradient(Amat, rhs, atol=1e-10, max_iter=1000)
    T_curr = T_curr + dT

    # Stored after the update of step `step`, i.e. at time (step + 1) * dt.
    if step % save_interval == 0:
        T_history.append(T_curr)
        print(f" Step {step:4d}/{n_steps}: T_min={T_curr.min():.4f}, T_max={T_curr.max():.4f}, CG iterations={cg_iters}")
print(f"Simulation complete in {time.perf_counter() - t_start:.2f} s")
Running 200 steps...
Step 0/200: T_min=0.0000, T_max=1.0000, CG iterations=82
Step 10/200: T_min=0.0019, T_max=1.0000, CG iterations=74
Step 20/200: T_min=0.0155, T_max=1.0000, CG iterations=71
Step 30/200: T_min=0.0376, T_max=1.0000, CG iterations=70
Step 40/200: T_min=0.0604, T_max=1.0000, CG iterations=69
Step 50/200: T_min=0.0767, T_max=1.0000, CG iterations=68
Step 60/200: T_min=0.0865, T_max=1.0000, CG iterations=67
Step 70/200: T_min=0.0951, T_max=1.0000, CG iterations=66
Step 80/200: T_min=0.1000, T_max=1.0000, CG iterations=66
Step 90/200: T_min=0.1000, T_max=1.0000, CG iterations=65
Step 100/200: T_min=0.1000, T_max=1.0000, CG iterations=64
Step 110/200: T_min=0.1000, T_max=1.0000, CG iterations=64
Step 120/200: T_min=0.1000, T_max=1.0000, CG iterations=63
Step 130/200: T_min=0.1000, T_max=1.0000, CG iterations=62
Step 140/200: T_min=0.1000, T_max=1.0000, CG iterations=62
Step 150/200: T_min=0.1000, T_max=1.0000, CG iterations=61
Step 160/200: T_min=0.1000, T_max=1.0000, CG iterations=60
Step 170/200: T_min=0.1000, T_max=1.0000, CG iterations=60
Step 180/200: T_min=0.1000, T_max=1.0000, CG iterations=59
Step 190/200: T_min=0.1000, T_max=1.0000, CG iterations=58
Simulation complete in 14.46 s
Results Visualization¤
We plot the evolution of temperature at the mesh node closest to the point \((0.5\ell_x, 0.8\ell_y)\).
# Pick the mesh node nearest to the monitoring point (0.5 * lx, 0.8 * ly).
monitor_pt = np.array([0.5 * lx, 0.8 * ly])
node_distances = np.linalg.norm(np.array(mesh.coords) - monitor_pt, axis=1)
monitor_node = int(np.argmin(node_distances))

# The first history entry is the initial state (t = 0); every later entry i
# is assigned the time (i * save_interval + 1) * dt.
n_saved_steps = len(T_history) - 1
snapshot_times = np.array(
    [0.0] + [(idx * save_interval + 1) * dt for idx in range(n_saved_steps)]
)
monitor_T = np.array([float(snapshot[monitor_node]) for snapshot in T_history])

fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(snapshot_times, monitor_T, marker='o', markersize=3)
ax.set_xlabel("Time [s]")
ax.set_ylabel("Temperature")
ax.set_title(f"Temperature evolution at {mesh.coords[monitor_node]}")
ax.grid(True)
plt.show()

We create an animation to see the evolution of temperature in the plate.
Code for creating an animation of the temperature distribution over time
from matplotlib.animation import FuncAnimation, PillowWriter

# NumPy copies of every stored temperature field, one per animation frame.
T_arr = [np.array(T) for T in T_history]
# Global extrema over all frames, so the colour scale is the same in each one.
T_min = min(frame.min() for frame in T_arr)
T_max = max(frame.max() for frame in T_arr)

fig, ax = plt.subplots(figsize=(5, 5))
ax.set_aspect("equal")
tc = ax.tripcolor(
    mesh.coords[:, 0],
    mesh.coords[:, 1],
    mesh.elements,
    T_arr[0],
    vmin=T_min,
    vmax=T_max,
    cmap="managua",
    shading="gouraud",
)
fig.colorbar(tc, label="Temperature")
title = ax.set_title(f"t = {snapshot_times[0]:.1f} s")


def animate(i):
    """Show stored field ``i`` and its time in the title; return both artists."""
    tc.set_array(T_arr[i])
    title.set_text(f"t = {snapshot_times[i]:.1f} s")
    return tc, title


anim = FuncAnimation(fig, animate, frames=len(T_history), blit=True, interval=100)
anim.save("heat_exchange.gif", writer=PillowWriter(fps=10))
plt.show()
