Operator

tatva.operator.Operator

Operator(
    mesh: Mesh,
    element: ElementT,
    batch_size: int | None = None,
    cache_weights: bool = False,
)

A class that provides an Operator for finite element method (FEM) assembly.

Parameters:

  • mesh

    (Mesh) –

    The mesh containing the elements and nodes.

  • element

    (ElementT) –

    The element type used for the finite element method.

  • batch_size

    (int | None, default: None ) –

    Optional batch size for mapping operations over elements. If None, it defaults to the number of elements in the mesh, so all elements are processed in a single batch. For meshes with many elements, a smaller batch size reduces peak memory usage.

  • cache_weights

    (bool, default: False ) –

    If True, the integration weights (the product of the determinant of the Jacobian and the quadrature weights) are computed once and cached for future use. This can speed up repeated integrations at the cost of increased memory usage.
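The memory trade-off behind batch_size can be illustrated with plain JAX. The sketch below is not tatva's implementation: the helper batched_vmap is hypothetical. It pads the element axis to a multiple of the batch size, vmaps within each batch, and loops over batches with jax.lax.map, so only one batch of intermediates is materialized at a time.

```python
import jax
import jax.numpy as jnp


def batched_vmap(func, xs, batch_size):
    """Hypothetical helper: apply `func` to each row of `xs`, `batch_size` rows at a time."""
    n = xs.shape[0]
    n_batches = -(-n // batch_size)  # ceiling division
    pad = n_batches * batch_size - n
    # Pad by repeating the last row; the padded results are dropped at the end.
    padded = jnp.concatenate([xs, jnp.repeat(xs[-1:], pad, axis=0)], axis=0)
    batched = padded.reshape((n_batches, batch_size) + xs.shape[1:])
    # vmap within a batch, sequential loop (lax.map) across batches.
    out = jax.lax.map(jax.vmap(func), batched)
    return out.reshape((n_batches * batch_size,) + out.shape[2:])[:n]


# Usage: one sum per row ("element"), processing two rows at a time.
xs = jnp.arange(10.0).reshape(5, 2)
result = batched_vmap(lambda row: row.sum(), xs, batch_size=2)
```

Recent JAX releases also accept a batch_size argument directly in jax.lax.map, which performs the same kind of chunked vmap.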

The class provides several methods for evaluating and integrating functions over the mesh, such as integrate, eval, and grad. They compute integrals, evaluate fields at quadrature points, and compute gradients of fields at quadrature points.

Example

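import jax.numpy as jnp
from tatva import Mesh, Tri3, Operator

mesh = Mesh.unit_square(10, 10)  # Create a mesh
element = Tri3()  # Define an element type
operator = Operator(mesh, element)
nodal_values = jnp.array(...)  # Nodal values at the mesh nodes, shape (n_nodes, n_values)

def energy_density(u):
    # Hypothetical density: 0.5 * |grad u|^2 per field component at the quadrature
    # points, shape (n_elements, n_quad_points, n_values)
    return 0.5 * jnp.sum(operator.grad(u) ** 2, axis=-1)

energy = operator.integrate(energy_density(nodal_values))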

Methods:

  • get_integration_weights

    Returns the integration weights for the quadrature points of the mesh.

  • map

    Maps a function over the elements and quad points of the mesh.

  • map_over_elements

    Maps a function over the elements of the mesh.

  • integrate

    Integrate a nodal_array, quad_array, or numeric value over the mesh.

  • integrate_per_element

    Integrate a nodal_array, quad_array, or numeric value over the mesh, returning the integral per element.

  • eval

    Evaluates the nodal values at the quadrature points.

  • grad

    Computes the gradient of the nodal values at the quad points.

  • interpolate

    Interpolates nodal values to a set of points in the physical space.

  • quads

    Returns the quadrature points of the mesh in physical coordinates.

  • project

    Projects a given field onto the finite element space defined by the mesh and element.

Attributes:

  • cache_weights

    If True, the integration weights are computed once and cached for reuse.

cache_weights class-attribute

cache_weights = False

Whether the integration weights (the product of the determinant of the Jacobian and the quadrature weights) are computed once and cached. Defaults to False.

get_integration_weights

get_integration_weights() -> Array

Returns the integration weights for the quadrature points of the mesh. This is the product of the determinant of the Jacobian and the quadrature weights, which can be used for integrating functions over the mesh.

Returns:

  • Array

    A jax.Array with the integration weights at each quadrature point of each element (shape: (n_elements, n_quad_points)).
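For linear triangles these weights have a closed form, because the Jacobian is constant on each element. The following sketch assumes a Tri3 element with reference vertices (0, 0), (1, 0), (0, 1) and a three-point quadrature rule with weights 1/6; the names DN_DXI, QUAD_WEIGHTS and integration_weights are illustrative and not part of tatva.

```python
import jax.numpy as jnp

# Reference Tri3 data (assumed): dN_a/dxi, rows = nodes a, columns = (xi, eta).
DN_DXI = jnp.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])
# Three-point rule on the reference triangle; the weights sum to its area, 1/2.
QUAD_WEIGHTS = jnp.full((3,), 1.0 / 6.0)


def integration_weights(coords, elements):
    """det(J) * w_q for every element and quadrature point, shape (n_elements, n_quad)."""
    x_e = coords[elements]                        # (n_elements, 3, 2) nodal coordinates
    jac = jnp.einsum("eai,aj->eij", x_e, DN_DXI)  # J = dx/dxi, (n_elements, 2, 2)
    det_j = jnp.linalg.det(jac)                   # constant on a linear triangle
    return det_j[:, None] * QUAD_WEIGHTS[None, :]


# Unit square split into two triangles.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
w = integration_weights(coords, elements)
```

Summing all weights gives the area of the mesh (here 1.0), which is a quick sanity check.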

map

map(
    func: MappableOverElementsAndQuads[P, RT],
    *,
    element_quantity: Sequence[int] = (),
) -> MappedCallable[P, RT]

Maps a function over the elements and quad points of the mesh.

Returns a function that takes values at nodal points (globally) and returns the vmapped result over the elements and quad points.

Parameters:

  • func
    (MappableOverElementsAndQuads[P, RT]) –

    The function to map over the elements and quadrature points.

  • element_quantity
    (Sequence[int], default: () ) –

    Indices of the arguments of func that are quantities defined per element. The rest of the arguments are assumed to be defined at nodal points.
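Vmapping over elements and quadrature points can be sketched with two nested jax.vmap calls. This is a simplified, hypothetical stand-in for Operator.map: it assumes func receives a reference quadrature point xi and the element's nodal values, which may not match the exact signature tatva expects, and it uses a toy two-triangle mesh with a three-point rule.

```python
import jax
import jax.numpy as jnp

# Toy mesh (assumed): unit square split into two Tri3 elements.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
# Three-point rule on the reference triangle.
XI_Q = jnp.array([[1 / 6, 1 / 6], [2 / 3, 1 / 6], [1 / 6, 2 / 3]])


def map_elements_and_quads(func, elements, quad_points):
    """Hypothetical: returns nodal_values -> func evaluated per (element, quad point)."""
    over_quads = jax.vmap(func, in_axes=(0, None))           # loop over xi_q
    over_elements = jax.vmap(over_quads, in_axes=(None, 0))  # loop over elements

    def mapped(nodal_values):
        # Gather global nodal values into per-element arrays: (n_elements, 3, n_values).
        return over_elements(quad_points, nodal_values[elements])

    return mapped


def value_at(xi, u_e):
    """Tri3 interpolation of the element nodal values u_e at reference point xi."""
    n = jnp.array([1.0 - xi[0] - xi[1], xi[0], xi[1]])
    return n @ u_e


u_q = map_elements_and_quads(value_at, elements, XI_Q)(coords)
```

Passing the coordinates as the nodal field, as in the last line, yields the physical positions of the quadrature points, with shape (n_elements, n_quad_points, 2).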

map_over_elements

map_over_elements(
    func: Callable[P, RT],
    *,
    element_quantity: Sequence[int] = (),
) -> MappedCallable[P, RT]

Maps a function over the elements of the mesh.

Returns a function that takes values at nodal points (globally) and returns the vmapped result over the elements.

Parameters:

  • func
    (Callable[P, RT]) –

    The function to map over the elements.

  • element_quantity
    (Sequence[int], default: () ) –

    Indices of the arguments of func that are quantities defined per element. The rest of the arguments are assumed to be defined at nodal points.
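The element_quantity argument decides which inputs are gathered through the connectivity and which are already given per element. A hypothetical sketch of that dispatch (not tatva's code), assuming nodal arrays are indexed by a connectivity array named elements:

```python
import jax
import jax.numpy as jnp

elements = jnp.array([[0, 1, 2], [0, 2, 3]])  # toy connectivity (assumed)


def map_over_elements_sketch(func, elements, element_quantity=()):
    """Hypothetical stand-in: vmap `func` over the elements of the mesh."""

    def mapped(*args):
        per_element = [
            # Per-element arguments pass through; nodal ones are gathered per element.
            arg if i in element_quantity else arg[elements]
            for i, arg in enumerate(args)
        ]
        return jax.vmap(func)(*per_element)

    return mapped


def scaled_sum(u_e, k_e):
    # u_e: the element's 3 nodal values; k_e: one coefficient for this element.
    return k_e * u_e.sum()


u = jnp.array([1.0, 2.0, 3.0, 4.0])  # one value per node
k = jnp.array([10.0, 100.0])         # one value per element
out = map_over_elements_sketch(scaled_sum, elements, element_quantity=(1,))(u, k)
```

Here argument 1 (k) is marked as an element quantity, so element 0 yields 10 * (1 + 2 + 3) and element 1 yields 100 * (1 + 3 + 4).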

integrate

integrate(arg: Array | Numeric) -> jax.Array

Integrate a nodal_array, quad_array, or numeric value over the mesh.

Parameters:

  • arg
    (Array | Numeric) –

    An array of nodal values (shape: (n_nodes, n_values)), an array of quadrature values (shape: (n_elements, n_quad_points, n_values)), or a numeric value (float or int).

Returns:

  • Array

    The integral of the nodal values or quadrature values over the mesh.
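In terms of the integration weights, integrating a quadrature array is a weighted sum over elements and quadrature points, a nodal array is first evaluated at the quadrature points, and a number c integrates to c times the measure of the mesh. A minimal sketch of these three cases, assuming a toy two-triangle mesh and a three-point rule (the helper integrate here is illustrative, not tatva's implementation):

```python
import jax.numpy as jnp

# Toy data (assumed): unit square as two Tri3 elements with det(J) = 1 and a
# three-point rule, so every weight det(J) * w_q equals 1/6.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
N_Q = jnp.array([[2 / 3, 1 / 6, 1 / 6],  # N_a at each quadrature point:
                 [1 / 6, 2 / 3, 1 / 6],  # rows = points, columns = nodes
                 [1 / 6, 1 / 6, 2 / 3]])
WEIGHTS = jnp.full((2, 3), 1.0 / 6.0)    # (n_elements, n_quad)


def integrate(arg):
    arg = jnp.asarray(arg)
    if arg.ndim == 0:  # numeric value: c * total measure
        return arg * WEIGHTS.sum()
    if arg.ndim == 2:  # nodal array (n_nodes, n_values): evaluate at quad points
        arg = jnp.einsum("qa,eav->eqv", N_Q, arg[elements])
    # Quadrature array (n_elements, n_quad, n_values): weighted sum.
    return jnp.einsum("eqv,eq->v", arg, WEIGHTS)


area = integrate(1.0)        # measure of the unit square
moments = integrate(coords)  # integrals of x and y over the unit square
```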

integrate_per_element

integrate_per_element(arg: Array | Numeric) -> jax.Array

Integrate a nodal_array, quad_array, or numeric value over the mesh, returning the integral per element.

Parameters:

  • arg
    (Array | Numeric) –

    An array of nodal values (shape: (n_nodes, n_values)), an array of quadrature values (shape: (n_elements, n_quad_points, n_values)), or a numeric value (float or int).

Returns:

  • Array

    A jax.Array containing the integral over each mesh element (shape: (n_elements, n_values)).
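Per-element integration is the same weighted sum without the final reduction over elements. A minimal sketch for a quadrature array, assuming a toy two-triangle mesh with uniform weights det(J) * w_q = 1/6 (helper names are illustrative):

```python
import jax.numpy as jnp

# Toy data (assumed): unit square as two Tri3 elements, three-point rule.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
N_Q = jnp.array([[2 / 3, 1 / 6, 1 / 6], [1 / 6, 2 / 3, 1 / 6], [1 / 6, 1 / 6, 2 / 3]])
WEIGHTS = jnp.full((2, 3), 1.0 / 6.0)  # det(J) * w_q, (n_elements, n_quad)


def integrate_per_element(quad_values):
    """(n_elements, n_quad, n_values) -> (n_elements, n_values); no sum over elements."""
    return jnp.einsum("eqv,eq->ev", quad_values, WEIGHTS)


# x coordinate at the quadrature points, shape (n_elements, n_quad, 1).
x_q = jnp.einsum("qa,ea->eq", N_Q, coords[elements][..., 0])[..., None]
per_elem = integrate_per_element(x_q)
```

Each entry equals the element area times the x coordinate of its centroid (1/2 * 2/3 and 1/2 * 1/3), and the two entries add up to the integral over the whole square.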

eval

eval(nodal_values: Array) -> jax.Array

Evaluates the nodal values at the quadrature points.

Parameters:

  • nodal_values
    (Array) –

    The nodal values at the mesh nodes (shape: (n_nodes, n_values)).

Returns:

  • Array

    A jax.Array with the field values at each quadrature point of each element (shape: (n_elements, n_quad_points, n_values)).
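Evaluation at the quadrature points is the shape-function sum u(ξ_q) = Σ_a N_a(ξ_q) u_a applied element by element. A minimal sketch for Tri3, assuming a toy two-triangle mesh and a tabulated three-point rule (eval_at_quads and N_Q are illustrative names):

```python
import jax.numpy as jnp

# Toy data (assumed): unit square as two Tri3 elements.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
# N_a(xi_q) for the three-point rule: rows = quadrature points, columns = nodes.
N_Q = jnp.array([[2 / 3, 1 / 6, 1 / 6], [1 / 6, 2 / 3, 1 / 6], [1 / 6, 1 / 6, 2 / 3]])


def eval_at_quads(nodal_values):
    """(n_nodes, n_values) -> (n_elements, n_quad, n_values)."""
    return jnp.einsum("qa,eav->eqv", N_Q, nodal_values[elements])


u = coords @ jnp.array([2.0, 3.0])  # nodal values of u = 2x + 3y
u_q = eval_at_quads(u[:, None])
```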

grad

grad(nodal_values: Array) -> jax.Array

Computes the gradient of the nodal values at the quad points.

Parameters:

  • nodal_values
    (Array) –

    The nodal values at the mesh nodes (shape: (n_nodes, n_values)).

Returns:

  • Array

    A jax.Array with the gradient of the nodal values at each quadrature point of each element (shape: (n_elements, n_quad_points, n_values, n_dim)).
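The gradient follows from the chain rule ∂u/∂x = (∂u/∂ξ) J⁻¹ with J = ∂x/∂ξ. A minimal sketch for Tri3 elements, where the gradient is constant per element and is broadcast to every quadrature point; the toy mesh, the three-point rule and the helper name grad_at_quads are assumptions for illustration:

```python
import jax.numpy as jnp

# Toy data (assumed): unit square as two Tri3 elements, three quadrature points.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
DN_DXI = jnp.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])  # dN_a/dxi on the reference triangle
N_QUAD = 3


def grad_at_quads(nodal_values):
    """(n_nodes, n_values) -> (n_elements, n_quad, n_values, n_dim) for Tri3."""
    x_e = coords[elements]                           # (n_elements, 3, 2)
    u_e = nodal_values[elements]                     # (n_elements, 3, n_values)
    jac = jnp.einsum("eai,aj->eij", x_e, DN_DXI)     # J = dx/dxi
    du_dxi = jnp.einsum("eav,aj->evj", u_e, DN_DXI)  # du/dxi
    du_dx = jnp.einsum("evj,ejk->evk", du_dxi, jnp.linalg.inv(jac))  # chain rule
    # Constant on a linear triangle, so broadcast to every quadrature point.
    return jnp.broadcast_to(du_dx[:, None], (du_dx.shape[0], N_QUAD) + du_dx.shape[1:])


u = (coords @ jnp.array([2.0, 3.0]))[:, None]  # u = 2x + 3y, shape (4, 1)
g = grad_at_quads(u)
```

For the linear field u = 2x + 3y the result is (2, 3) at every quadrature point of both elements.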

interpolate

interpolate(arg: Array, points: Array) -> jax.Array

Interpolates nodal values to a set of points in the physical space.

Parameters:

  • arg
    (Array) –

    The nodal values to interpolate.

  • points
    (Array) –

    The points in physical space at which to evaluate the interpolated nodal values.

Returns:

  • Array

    A jax.Array with the interpolated values at the given points.
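Interpolating at arbitrary points requires locating the element that contains each point and evaluating the shape functions at the point's reference coordinates. The sketch below does a brute-force search over all elements of a toy Tri3 mesh; it is an illustration under stated assumptions, not how tatva locates points, and interpolate_tri3 is a hypothetical name.

```python
import jax.numpy as jnp

# Toy data (assumed): unit square as two Tri3 elements.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
DN_DXI = jnp.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])


def interpolate_tri3(nodal_values, points, tol=1e-6):
    """Hypothetical brute-force point location followed by Tri3 interpolation."""
    x_e = coords[elements]                                   # (n_elements, 3, 2)
    inv_jac = jnp.linalg.inv(jnp.einsum("eai,aj->eij", x_e, DN_DXI))
    # Reference coordinates of every point in every element: (n_points, n_elements, 2).
    d = points[:, None, :] - x_e[None, :, 0, :]
    xi = jnp.einsum("eij,pej->pei", inv_jac, d)
    shape = jnp.concatenate([1.0 - xi.sum(-1, keepdims=True), xi], axis=-1)
    inside = jnp.all(shape >= -tol, axis=-1)                 # (n_points, n_elements)
    # First containing element; points outside the mesh silently fall back to element 0.
    e = jnp.argmax(inside.astype(jnp.int32), axis=1)
    n_sel = shape[jnp.arange(points.shape[0]), e]            # (n_points, 3)
    return jnp.einsum("pa,pav->pv", n_sel, nodal_values[elements[e]])


u = (coords @ jnp.array([2.0, 3.0]))[:, None]  # u = 2x + 3y
vals = interpolate_tri3(u, jnp.array([[0.25, 0.5], [0.9, 0.1]]))
```

Because u is linear, Tri3 interpolation reproduces it exactly at both points.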

quads

quads() -> jax.Array

Returns the quadrature points of the mesh in physical coordinates.

Same as op.eval(op.mesh.coords).

Returns:

  • Array

    An array with the quadrature points of the mesh in physical coordinates (shape: (n_elements, n_quad_points, n_dim)).
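For affine elements such as Tri3, the physical quadrature points can also be obtained from the element map x(ξ) = x_0 + J ξ, which agrees with evaluating the nodal coordinates with the shape functions. A sketch assuming a toy two-triangle mesh and a three-point rule (quads_affine and quads_via_eval are illustrative names):

```python
import jax.numpy as jnp

# Toy data (assumed): unit square as two Tri3 elements.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
DN_DXI = jnp.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])
XI_Q = jnp.array([[1 / 6, 1 / 6], [2 / 3, 1 / 6], [1 / 6, 2 / 3]])  # reference points
N_Q = jnp.array([[2 / 3, 1 / 6, 1 / 6], [1 / 6, 2 / 3, 1 / 6], [1 / 6, 1 / 6, 2 / 3]])


def quads_affine():
    """x_q = x_0 + J xi_q, shape (n_elements, n_quad, 2)."""
    x_e = coords[elements]
    jac = jnp.einsum("eai,aj->eij", x_e, DN_DXI)
    return x_e[:, None, 0, :] + jnp.einsum("eij,qj->eqi", jac, XI_Q)


def quads_via_eval():
    """Same points via the shape functions, i.e. eval applied to the coordinates."""
    return jnp.einsum("qa,eai->eqi", N_Q, coords[elements])


q = quads_affine()
```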

project

project(
    field: Array,
    colored_matrix: ColoredMatrix | None = None,
    lifter: Lifter | None = None,
) -> Array

Projects a given field onto the finite element space defined by the mesh and element.

Uses jax.experimental.sparse.linalg.spsolve to solve the linear system resulting from the projection. If colored_matrix is None (the default), a compatible colored matrix is assembled from self.mesh.elements. When a colored_matrix is passed explicitly, it must be compatible with the dimensions of the projected field and with the chosen FEM space.

Parameters:

  • field
    (Array) –

    The field to project, defined at the quadrature points (shape: (n_elements, n_quad_points, ...)).

  • colored_matrix
    (ColoredMatrix | None, default: None ) –

    Optional colored matrix representing the finite element space. If omitted, it is constructed from self.mesh.elements.

  • lifter
    (Lifter | None, default: None ) –

    Optional lifter used to lift and reduce between the full and reduced spaces.
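The projection is an L2 projection: find nodal values u such that M u = b, where M_ab = ∫ N_a N_b dx is the mass matrix and b_a = ∫ N_a f dx. The sketch below assembles M densely and calls jnp.linalg.solve instead of jax.experimental.sparse.linalg.spsolve with a ColoredMatrix, which keeps it short but only suits tiny meshes. It assumes a scalar field of shape (n_elements, n_quad) on a toy Tri3 mesh, ignores the lifter, and all names are hypothetical.

```python
import jax.numpy as jnp

# Toy data (assumed): unit square as two Tri3 elements, three-point rule.
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = jnp.array([[0, 1, 2], [0, 2, 3]])
DN_DXI = jnp.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])
N_Q = jnp.array([[2 / 3, 1 / 6, 1 / 6], [1 / 6, 2 / 3, 1 / 6], [1 / 6, 1 / 6, 2 / 3]])
QUAD_WEIGHTS = jnp.full((3,), 1.0 / 6.0)


def project_dense(field):
    """L2-project a scalar quadrature field (n_elements, n_quad) onto Tri3 nodal values."""
    n_nodes = coords.shape[0]
    jac = jnp.einsum("eai,aj->eij", coords[elements], DN_DXI)
    w = jnp.linalg.det(jac)[:, None] * QUAD_WEIGHTS  # det(J) * w_q
    m_e = jnp.einsum("qa,qb,eq->eab", N_Q, N_Q, w)   # element mass matrices
    b_e = jnp.einsum("qa,eq,eq->ea", N_Q, field, w)  # element load vectors
    # Scatter-add element contributions into the global (dense) system.
    m = jnp.zeros((n_nodes, n_nodes)).at[elements[:, :, None], elements[:, None, :]].add(m_e)
    b = jnp.zeros(n_nodes).at[elements].add(b_e)
    return jnp.linalg.solve(m, b)


# x evaluated at the quadrature points lies in the Tri3 space, so it is recovered exactly.
x_q = jnp.einsum("qa,ea->eq", N_Q, coords[elements][..., 0])
u_x = project_dense(x_q)
```

The three-point rule integrates the quadratic products N_a N_b exactly, so projecting a field that already lies in the finite element space returns its nodal values.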