# Source code for FIAT.nodal_enriched

# Copyright (C) 2013 Andrew T. T. McRae, 2015-2016 Jan Blechta
#
# This file is part of FIAT (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

import numpy as np

from FIAT.polynomial_set import PolynomialSet
from FIAT.dual_set import DualSet
from FIAT.finite_element import CiarletElement

__all__ = ['NodalEnrichedElement']


class NodalEnrichedElement(CiarletElement):
    """NodalEnriched element is a direct sum of a sequence of finite
    elements. Dual basis is reorthogonalized to the primal basis for
    nodality.

    The following are equivalent:
        * the constructor is well-defined,
        * the resulting element is unisolvent and its basis is nodal,
        * the supplied elements are unisolvent with nodal basis and
          their primal bases are mutually linearly independent,
        * the supplied elements are unisolvent with nodal basis and
          their dual bases are mutually linearly independent.
    """

    def __init__(self, *elements):
        """Build the direct sum of ``elements``.

        :arg elements: one or more nodal elements. They must share the
            reference element, the expansion-set type, the mapping and the
            value shape (checked by assertions below).
        :raises ValueError: if no element is given, or if any element is
            not nodal.
        """
        # Fail clearly instead of with an IndexError on elements[0] below.
        if not elements:
            raise ValueError("NodalEnrichedElement requires at least one "
                             "element")

        # Test elements are nodal
        if not all(e.is_nodal() for e in elements):
            raise ValueError("Not all elements given for construction "
                             "of NodalEnrichedElement are nodal")

        # Extract common data
        ref_el = elements[0].get_reference_element()
        expansion_set = elements[0].get_nodal_basis().get_expansion_set()
        # Smallest degree among the summands; presumably the degree up to
        # which the sum is complete -- TODO confirm against PolynomialSet.
        degree = min(e.get_nodal_basis().get_degree() for e in elements)
        # Largest embedded degree: the merged set must hold every summand.
        embedded_degree = max(e.get_nodal_basis().get_embedded_degree()
                              for e in elements)
        order = max(e.get_order() for e in elements)
        mapping = elements[0].mapping()[0]
        # Form degree is unknown (None) as soon as one summand's is unknown.
        if any(e.get_formdegree() is None for e in elements):
            formdegree = None
        else:
            formdegree = max(e.get_formdegree() for e in elements)
        value_shape = elements[0].value_shape()

        # Sanity check
        assert all(e.get_nodal_basis().get_reference_element() == ref_el
                   for e in elements)
        # Exact type match is intended here (not isinstance).
        assert all(type(e.get_nodal_basis().get_expansion_set())
                   is type(expansion_set)
                   for e in elements)
        assert all(e_mapping == mapping
                   for e in elements for e_mapping in e.mapping())
        assert all(e.value_shape() == value_shape for e in elements)

        # Merge polynomial sets
        coeffs = _merge_coeffs([e.get_coeffs() for e in elements])
        dmats = _merge_dmats([e.dmats() for e in elements])
        poly_set = PolynomialSet(ref_el, degree, embedded_degree,
                                 expansion_set, coeffs, dmats)

        # Renumber dof numbers: element i's dofs start after the dofs of
        # all preceding elements.
        offsets = np.cumsum([0] + [e.space_dimension()
                                   for e in elements[:-1]])
        entity_ids = _merge_entity_ids((e.entity_dofs() for e in elements),
                                       offsets)

        # Merge dual bases
        nodes = [node for e in elements for node in e.dual_basis()]
        dual_set = DualSet(nodes, ref_el, entity_ids)

        # CiarletElement constructor adjusts poly_set coefficients s.t.
        # dual_set is really dual to poly_set
        super().__init__(poly_set, dual_set, order,
                         formdegree=formdegree, mapping=mapping)
def _merge_coeffs(coeffss): # Number of bases members total_dim = sum(c.shape[0] for c in coeffss) # Value shape value_shape = coeffss[0].shape[1:-1] assert all(c.shape[1:-1] == value_shape for c in coeffss) # Number of expansion polynomials max_expansion_dim = max(c.shape[-1] for c in coeffss) # Compose new coeffs shape = (total_dim,) + value_shape + (max_expansion_dim,) new_coeffs = np.zeros(shape, dtype=coeffss[0].dtype) counter = 0 for c in coeffss: dim = c.shape[0] expansion_dim = c.shape[-1] new_coeffs[counter:counter+dim, ..., :expansion_dim] = c counter += dim assert counter == total_dim return new_coeffs def _merge_dmats(dmatss): shape, arg = max((dmats[0].shape, args) for args, dmats in enumerate(dmatss)) assert len(shape) == 2 and shape[0] == shape[1] new_dmats = [] for dim in range(len(dmatss[arg])): new_dmats.append(dmatss[arg][dim].copy()) for dmats in dmatss: sl = slice(0, dmats[dim].shape[0]), slice(0, dmats[dim].shape[1]) assert np.allclose(dmats[dim], new_dmats[dim][sl]), \ "dmats of elements to be directly summed are not matching!" return new_dmats def _merge_entity_ids(entity_ids, offsets): ret = {} for i, ids in enumerate(entity_ids): for dim in ids: if not ret.get(dim): ret[dim] = {} for entity in ids[dim]: if not ret[dim].get(entity): ret[dim][entity] = [] ret[dim][entity] += (np.array(ids[dim][entity]) + offsets[i]).tolist() return ret