Operator
tatva.operator.Operator
Operator(mesh: Mesh, element: Element, batch_size: int | None = None)
A class that provides an Operator for finite element method (FEM) assembly.
Parameters:
- mesh (Mesh) – The mesh containing the elements and nodes.
- element (Element) – The element type used for the finite element method.
Provides several operators for evaluating and integrating functions over the mesh,
such as integrate, eval, and grad. These operators can be used to compute
integrals, evaluate functions at quadrature points, and compute gradients of
functions at quadrature points.
Example
```python
import jax.numpy as jnp
from tatva import Mesh, Tri3, Operator

mesh = Mesh.unit_square(10, 10)  # Create a mesh
element = Tri3()  # Define an element type
operator = Operator(mesh, element)

nodal_values = jnp.array(...)  # Nodal values at the mesh nodes
# energy_density is a user-defined function (not shown here)
energy = operator.integrate(energy_density)(nodal_values)
```
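All of these operators take a global nodal array and work on it element by element. The following plain-NumPy sketch shows that first step, gathering a global nodal array into per-element blocks through the connectivity. The `nodes` and `elements` arrays describe a hypothetical two-triangle mesh; they are not tatva's `Mesh` attributes, and this is not tatva's implementation.

```python
import numpy as np

# Hypothetical mesh: the unit square split into two 3-node triangles.
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])  # (n_nodes, n_dim)
elements = np.array([[0, 1, 2], [0, 2, 3]])  # (n_elements, nodes_per_element)

# Global nodal values, one scalar per node: u = x + 2y, shape (n_nodes, n_values).
nodal_values = nodes[:, :1] + 2.0 * nodes[:, 1:]

# Gather into a per-element view: (n_elements, nodes_per_element, n_values).
element_values = nodal_values[elements]
print(element_values[:, :, 0])
# [[0. 1. 3.]
#  [0. 3. 2.]]
```

Nodes 0 and 2 are shared by both triangles, so their values appear in both element blocks. Assembling element contributions back into a global array is the reverse operation: a scatter-add over the same connectivity.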
Methods:
- map – Maps a function over the elements and quadrature points of the mesh.
- map_over_elements – Maps a function over the elements of the mesh.
- integrate – Integrates a nodal_array, quad_array, or numeric value over the mesh.
- integrate_per_element – Integrates a nodal_array, quad_array, or numeric value over the mesh, returning the integral per element.
- eval – Evaluates the nodal values at the quadrature points.
- grad – Computes the gradient of the nodal values at the quadrature points.
- interpolate – Interpolates nodal values to a set of points in the physical space.
map
map(func: MappableOverElementsAndQuads[P, RT], *, element_quantity: Sequence[int] = ()) -> MappedCallable[P, RT]
Maps a function over the elements and quad points of the mesh.
Returns a function that takes values at nodal points (globally) and returns the vmapped result over the elements and quad points.
Parameters:
- func (MappableOverElementsAndQuads[P, RT]) – The function to map over the elements and quadrature points.
- element_quantity (Sequence[int], default: ()) – Indices of the arguments of func that are quantities defined per element. The remaining arguments are assumed to be defined at nodal points.
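To make the mapping semantics concrete, here is a loop-based reference sketch in NumPy. It assumes `func` receives the value interpolated at one quadrature point plus one per-element quantity; the exact arguments tatva passes to a `MappableOverElementsAndQuads` may differ. The mesh arrays, the 3-point quadrature rule, and the helper `map_reference` are all hypothetical. tatva vectorizes this with vmap; the loops here only make the element/quadrature structure explicit.

```python
import numpy as np

# Hypothetical mesh: the unit square split into two 3-node triangles.
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1, 2], [0, 2, 3]])

# Assumed 3-point rule on the reference triangle; N[q, a] = shape function a at point q.
xi = np.array([[1 / 6, 1 / 6], [2 / 3, 1 / 6], [1 / 6, 2 / 3]])
N = np.stack([1 - xi[:, 0] - xi[:, 1], xi[:, 0], xi[:, 1]], axis=1)


def map_reference(func):
    """Return a callable applying func at every (element, quadrature point) pair.

    nodal_values is global (n_nodes, n_values) and is interpolated to each point;
    per_element has one entry per element, reused at all points of that element.
    """
    def mapped(nodal_values, per_element):
        out = []
        for e, conn in enumerate(elements):
            u_e = nodal_values[conn]  # (3, n_values): values at this element's nodes
            out.append([func(N[q] @ u_e, per_element[e]) for q in range(len(N))])
        return np.array(out)  # (n_elements, n_quad, n_values)
    return mapped


def weighted_square(u, k):
    return k * u**2


mapped = map_reference(weighted_square)
result = mapped(nodes[:, :1], np.array([1.0, 10.0]))  # u = x; k = 1 and 10 per element
print(result.shape)  # (2, 3, 1)
```

The per-element argument plays the role that the `element_quantity` indices mark in `map`: it is indexed by element rather than gathered from nodes.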
map_over_elements
map_over_elements(func: Callable[P, RT], *, element_quantity: Sequence[int] = ()) -> MappedCallable[P, RT]
Maps a function over the elements of the mesh.
Returns a function that takes values at nodal points (globally) and returns the vmapped result over the elements.
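A per-element map can be sketched in NumPy as follows: nodal arguments are gathered through the connectivity, while arguments flagged as element quantities are indexed by element. The helper `map_over_elements_reference` only mirrors the documented `element_quantity` parameter; it is a hypothetical illustration, not tatva's code, and the mesh and thickness values are made up.

```python
import numpy as np

# Hypothetical mesh: the unit square split into two 3-node triangles.
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1, 2], [0, 2, 3]])


def map_over_elements_reference(func, element_quantity=()):
    """Return a callable that applies func once per element.

    Arguments whose index is in element_quantity are indexed per element;
    all other arguments are treated as global nodal arrays and gathered.
    """
    def mapped(*args):
        results = []
        for e, conn in enumerate(elements):
            local = [a[e] if i in element_quantity else a[conn] for i, a in enumerate(args)]
            results.append(func(*local))
        return np.stack(results)
    return mapped


def prism_volume(x_e, thickness):
    # Triangle area from its vertex coordinates x_e (3, 2), times a per-element thickness.
    d1 = x_e[1] - x_e[0]
    d2 = x_e[2] - x_e[0]
    area = 0.5 * abs(d1[0] * d2[1] - d1[1] * d2[0])
    return area * thickness


volume = map_over_elements_reference(prism_volume, element_quantity=(1,))
per_elem = volume(nodes, np.array([2.0, 4.0]))
print(per_elem)  # [1. 2.]
```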
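Parameters:
- func (Callable[P, RT]) – The function to map over the elements.
- element_quantity (Sequence[int], default: ()) – Indices of the arguments of func that are quantities defined per element. The remaining arguments are assumed to be defined at nodal points.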
integrate
integrate(arg: Array | Numeric) -> jax.Array
Integrate a nodal_array, quad_array, or numeric value over the mesh.
Parameters:
- arg (Array | Numeric) – An array of nodal values (shape: (n_nodes, n_values)), an array of quadrature values (shape: (n_elements, n_quad_points, n_values)), or a numeric value (float or int).
Returns:
- Array – The integral of the nodal values or quadrature values over the mesh.
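Integration over the mesh amounts to a quadrature sum: values at quadrature points, times quadrature weights, times the Jacobian determinant of each element, summed over points and elements. The NumPy sketch below handles the three accepted input kinds for linear triangles. The mesh, the 3-point rule, and the rule of telling nodal from quadrature arrays by `ndim` are assumptions made for illustration, not tatva's implementation.

```python
import numpy as np

# Hypothetical mesh: the unit square split into two 3-node triangles.
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1, 2], [0, 2, 3]])

# Assumed 3-point rule on the reference triangle (weights sum to its area, 1/2).
xi = np.array([[1 / 6, 1 / 6], [2 / 3, 1 / 6], [1 / 6, 2 / 3]])
w = np.full(3, 1 / 6)
N = np.stack([1 - xi[:, 0] - xi[:, 1], xi[:, 0], xi[:, 1]], axis=1)  # (n_quad, 3)

# |det J| of the affine reference-to-physical map, one per element.
x_e = nodes[elements]  # (n_elements, 3, 2)
J = np.stack([x_e[:, 1] - x_e[:, 0], x_e[:, 2] - x_e[:, 0]], axis=-1)
detJ = np.abs(np.linalg.det(J))


def integrate(arg):
    if np.isscalar(arg):
        # Numeric value: treated as a constant field at every quadrature point.
        quad = np.full((len(elements), len(w), 1), float(arg))
    elif arg.ndim == 2:
        # Nodal array (n_nodes, n_values): interpolate to quadrature points first.
        quad = np.einsum("qa,eav->eqv", N, arg[elements])
    else:
        # Quadrature array (n_elements, n_quad, n_values): use as is.
        quad = arg
    # Sum over elements e and points q of value * weight * |det J|.
    return np.einsum("eqv,q,e->v", quad, w, detJ)


print(integrate(1.0))  # area of the unit square: [1.]
print(integrate(nodes[:, :1]))  # integral of x over the unit square: [0.5]
```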
integrate_per_element
integrate_per_element(arg: Array | Numeric) -> jax.Array
Integrate a nodal_array, quad_array, or numeric value over the mesh, returning the integral per element.
Parameters:
- arg (Array | Numeric) – An array of nodal values (shape: (n_nodes, n_values)), an array of quadrature values (shape: (n_elements, n_quad_points, n_values)), or a numeric value (float or int).
Returns:
- Array – A jax.Array where each entry contains the integral of the values over the corresponding element (shape: (n_elements, n_values)).
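The per-element variant is the same quadrature sum without the final reduction over elements, so summing its result recovers the total integral. This NumPy sketch covers only the nodal-array case; the mesh, the 3-point rule, and the helper name are hypothetical.

```python
import numpy as np

# Hypothetical mesh: the unit square split into two 3-node triangles.
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1, 2], [0, 2, 3]])

# Assumed 3-point rule on the reference triangle.
xi = np.array([[1 / 6, 1 / 6], [2 / 3, 1 / 6], [1 / 6, 2 / 3]])
w = np.full(3, 1 / 6)
N = np.stack([1 - xi[:, 0] - xi[:, 1], xi[:, 0], xi[:, 1]], axis=1)

x_e = nodes[elements]
J = np.stack([x_e[:, 1] - x_e[:, 0], x_e[:, 2] - x_e[:, 0]], axis=-1)
detJ = np.abs(np.linalg.det(J))


def integrate_per_element_sketch(nodal_values):
    quad = np.einsum("qa,eav->eqv", N, nodal_values[elements])
    # Keep the element axis: result has shape (n_elements, n_values).
    return np.einsum("eqv,q,e->ev", quad, w, detJ)


per_elem = integrate_per_element_sketch(nodes[:, :1])  # integral of x on each triangle
print(per_elem[:, 0])  # approximately [0.3333, 0.1667]
print(per_elem.sum())  # approximately 0.5, the integral over the whole square
```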
eval
eval(nodal_values: Array) -> jax.Array
Evaluates the nodal values at the quadrature points.
Parameters:
- nodal_values (Array) – The nodal values at the mesh nodes (shape: (n_nodes, n_values)).
Returns:
- Array – A jax.Array with the nodal values interpolated to each quadrature point of each element (shape: (n_elements, n_quad_points, n_values)).
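Evaluation at quadrature points is the shape-function interpolation u(x_q) = Σ_a N_a(ξ_q) u_a, applied to every element at once. In the NumPy sketch below, the mesh, the 3-point rule, and the helper `eval_at_quads` (named to avoid shadowing Python's built-in `eval`) are hypothetical, not tatva's internals.

```python
import numpy as np

# Hypothetical mesh: the unit square split into two 3-node triangles.
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1, 2], [0, 2, 3]])

# Assumed 3-point rule on the reference triangle; N[q, a] = N_a(xi_q) for linear shapes.
xi = np.array([[1 / 6, 1 / 6], [2 / 3, 1 / 6], [1 / 6, 2 / 3]])
N = np.stack([1 - xi[:, 0] - xi[:, 1], xi[:, 0], xi[:, 1]], axis=1)


def eval_at_quads(nodal_values):
    # Gather per element, then contract over local nodes a: (n_elements, n_quad, n_values).
    return np.einsum("qa,eav->eqv", N, nodal_values[elements])


u_q = eval_at_quads(nodes[:, :1])  # scalar field u = x
x_q = eval_at_quads(nodes)  # interpolating the coordinates gives the quadrature points
print(u_q.shape, x_q.shape)  # (2, 3, 1) (2, 3, 2)
```

Because u = x here, `u_q` matches the first column of `x_q`.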
grad
grad(nodal_values: Array) -> jax.Array
Computes the gradient of the nodal values at the quad points.
Parameters:
- nodal_values (Array) – The nodal values at the mesh nodes (shape: (n_nodes, n_values)).
Returns:
- Array – A jax.Array with the gradient of the nodal values at each quadrature point of each element (shape: (n_elements, n_quad_points, n_values, n_dim)).
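Physical gradients follow from the chain rule through the element Jacobian: ∂N_a/∂x_i = Σ_j (∂N_a/∂ξ_j)(J⁻¹)_{ji}, where J_{ij} = ∂x_i/∂ξ_j. The NumPy sketch below does this for linear triangles and returns the documented layout (n_elements, n_quad_points, n_values, n_dim). The mesh, the count of 3 quadrature points, and the helper `grad_at_quads` are assumptions for illustration.

```python
import numpy as np

# Hypothetical mesh: the unit square split into two 3-node triangles.
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1, 2], [0, 2, 3]])
n_quad = 3  # assumed number of quadrature points per element

# Derivatives of the linear shape functions w.r.t. reference coords: dN_dxi[a, j].
dN_dxi = np.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])


def grad_at_quads(nodal_values):
    x_e = nodes[elements]  # (n_elements, 3, n_dim)
    # Jacobian J[e, i, j] = d x_i / d xi_j of the affine map.
    J = np.einsum("eai,aj->eij", x_e, dN_dxi)
    # Chain rule: dN_a/dx_i = sum_j dN_a/dxi_j * inv(J)[j, i].
    dN_dx = np.einsum("aj,eji->eai", dN_dxi, np.linalg.inv(J))
    # grad u[e, v, i] = sum_a u_a[v] * dN_a/dx_i, constant inside each linear triangle.
    g = np.einsum("eav,eai->evi", nodal_values[elements], dN_dx)
    # Repeat at every quadrature point: (n_elements, n_quad, n_values, n_dim).
    return np.repeat(g[:, None], n_quad, axis=1)


u = nodes @ np.array([[2.0], [3.0]])  # u = 2x + 3y, shape (n_nodes, 1)
G = grad_at_quads(u)
print(G.shape)  # (2, 3, 1, 2)
```

For linear triangles the gradient is constant per element, which is why it is simply repeated; higher-order elements would need `dN_dxi` evaluated at each quadrature point.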
interpolate
interpolate(arg: Array, points: Array) -> jax.Array