Source code for VolterraBasis.basis._fem_features

"""
This module provides finite element basis features (wrappers around scikit-fem bases).
"""
import numpy as np

from sklearn.base import TransformerMixin
from scipy.spatial import cKDTree

import sparse


class ElementFinder:
    """
    Find the mesh element that contains each query point.

    In 1D the points are binned against the sorted mesh nodes. In higher
    dimensions, the elements whose centroids are closest to the point are
    taken as candidates (KD-tree query) and the first candidate whose
    reference coordinates lie inside the reference simplex is returned.

    Parameters
    ----------
    mesh : mesh object
        Must expose ``dim()``, node coordinates ``p`` (indexed here as
        ``p[coordinate, node]``) and connectivity ``t`` (indexed here as
        ``t[local_vertex, element]``). Presumably a scikit-fem mesh.
    mapping : mapping object, optional
        Object exposing ``invF(x, tind=...)``. Only used when
        ``mesh.dim() > 1``; defaults to ``mesh._mapping()``.

    Attributes
    ----------
    find : callable
        The lookup method selected for the mesh dimension; call it with an
        array of points of shape (nsamples, dim).
    """

    def __init__(self, mesh, mapping=None):
        # Get dimension from mesh
        self.dim = mesh.dim()
        if self.dim == 1:
            # Largest node coordinate: np.digitize works with half-open bins,
            # so points lying exactly on it are special-cased in the lookup.
            self.max_point = np.max(mesh.p)
            self.bins = np.sort(mesh.p[0])
            # NOTE(review): bins_idx holds node indices taken from mesh.t[0]
            # while the digitize result indexes the *sorted* nodes. Both agree
            # when nodes are stored in increasing order and t[0] is the left
            # node of each element -- confirm for unsorted meshes.
            self.bins_idx = mesh.t[0]
            self.find = self.find_element_1D
        else:
            # The dedicated 2D path (find_element_2D / mesh.element_finder) is
            # currently disabled: every dim > 1 goes through the KD-tree search.
            # One tree point per element: the centroid of its vertices.
            self.tree = cKDTree(np.mean(mesh.p[:, mesh.t], axis=1).T)
            self.find = self.find_element_ND
            self.mapping = mesh._mapping() if mapping is None else mapping
            if self.dim == 2:  # We should also check the type of element
                self.inside = self.inside_2D
            elif self.dim == 3:
                self.inside = self.inside_3D

    def find_element_1D(self, X):
        """
        Return the element index of each point of a 1D mesh.

        Parameters
        ----------
        X : array of shape (nsamples, 1)
            Query points. The array is left untouched.

        Returns
        -------
        array of int, shape (nsamples,)
            Index of the element containing each point.
        """
        x = np.asarray(X)[:, 0]
        node_idx = np.digitize(x, self.bins) - 1
        # np.digitize bins are half-open, so the last node would fall past the
        # last bin: assign it to the last interval explicitly instead of
        # nudging the caller's data.
        node_idx[x == self.max_point] = self.bins.size - 2
        return np.argmax(node_idx[:, None] == self.bins_idx, axis=1)

    def find_element_2D(self, X):
        """
        Look up elements through ``self.mesh_finder``.

        ``__init__`` does not set ``mesh_finder`` (that code path is disabled),
        so this method is not usable unless the attribute is assigned
        externally.
        """
        return self.mesh_finder(X[:, 0], X[:, 1])

    def find_element_ND(self, X, _search_all=False):
        """
        Return the element index of each point using a centroid KD-tree.

        Parameters
        ----------
        X : array of shape (nsamples, dim)
            Query points.
        _search_all : bool, optional
            Currently unused.

        Returns
        -------
        array of int, shape (nsamples,)
            For each point, the first of the nearest-centroid candidates that
            contains it. If no candidate contains the point, the element with
            the nearest centroid is returned.
        """
        nsamples = X.shape[0]
        # Never ask for more neighbours than there are elements: cKDTree pads
        # missing neighbours with the out-of-range index ``tree.n``.
        n_candidates = min(5, self.tree.n)
        # query() drops the neighbour axis when k == 1; restore it.
        candidates = self.tree.query(X, n_candidates)[1].reshape(nsamples, n_candidates)
        element_inds = np.empty((nsamples,), dtype=int)
        for n, point in enumerate(X):  # Try to avoid loop
            i_e = candidates[n]
            X_loc = self.mapping.invF(point[:, None, None], tind=i_e)
            # One flag per candidate; argmax picks the first True (or the
            # nearest centroid when every flag is False).
            inside = np.ravel(self.inside(X_loc))
            element_inds[n] = i_e[np.argmax(inside)]
        return element_inds

    def inside_2D(self, X):  # Do something more general from Refdom?
        """
        Say which points are inside the reference triangle.

        ``X[0]`` and ``X[1]`` are the reference coordinates; a tolerance of one
        machine epsilon is allowed on every constraint.
        """
        tol = -np.finfo(X.dtype).eps
        return (X[0] >= tol) & (X[1] >= tol) & (1 - X[0] - X[1] >= tol)

    def inside_3D(self, X):
        """
        Say which points are inside the reference tetrahedron.

        ``X[0]``, ``X[1]`` and ``X[2]`` are the reference coordinates; the same
        machine-epsilon tolerance as in ``inside_2D`` is used on every
        constraint.
        """
        tol = -np.finfo(X.dtype).eps
        return (X[0] >= tol) & (X[1] >= tol) & (X[2] >= tol) & (1 - X[0] - X[1] - X[2] >= tol)


class FEMScalarFeatures(TransformerMixin):
    """
    Finite elements features for scalar basis
    """

    def __init__(self, basis):
        """
        Wrapper to finite element basis from scikit-fem

        Parameters
        ----------
        basis: skfem basis
            A finite element basis. Should be a scalar basis (H1 or global element)
        """
        # Since this only works for H1 elements, the basis could simply be
        # built locally from the mesh and the element, checking that the
        # element derives from H1. It can also work for global elements; in
        # short, only scalar elements are allowed for now.
        self.basis_fem = basis
        self.const_removed = False

    def element_finder(self, x):
        """
        Return, for each row of ``x``, the index of the mesh element found by
        the ``ElementFinder`` built from the wrapped basis.
        """
        # On first use, instantiate the element finder if it does not exist yet
        if not hasattr(self, "element_finder_from_basis"):
            self.element_finder_from_basis = ElementFinder(self.basis_fem.mesh, mapping=self.basis_fem.mapping)
        return self.element_finder_from_basis.find(x)

    def fit(self, describe_result):
        """
        Prepare the element finder and record the number of output features.

        Parameters
        ----------
        describe_result : object
            Not used by this method.

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If the probed basis function has a 3-dimensional shape
            (vectorial basis).
        """
        fem = self.basis_fem
        self.element_finder_from_basis = ElementFinder(fem.mesh, mapping=fem.mapping)
        # Find tensorial order of the basis and adapt dimension in consequence
        test = fem.elem.gbasis(fem.mapping, fem.mapping.F(fem.mesh.p), 0)[0]
        if len(test.shape) == 3:  # Vectorial basis
            raise NotImplementedError("Unsupported Element, please use the FEMVectorFeatures")
        elif len(test.shape) == 2:  # Scalar basis
            self.dim_out_basis = 1
        self.n_output_features_ = fem.N
        return self

    def basis(self, X):
        """
        Evaluate every basis function at the points ``X``.

        Parameters
        ----------
        X : array of shape (nsamples, dim)

        Returns
        -------
        dense array of shape (nsamples, n_output_features_)
            Requires ``fit`` to have been called (reads
            ``n_output_features_``).
        """
        fem = self.basis_fem
        nsamples, _ = X.shape
        cells = self.element_finder(X)
        # Reference coordinates of every point within its own element
        pts = fem.mapping.invF(X.T[:, :, np.newaxis], tind=cells)
        # TODO: check the shape
        phis = np.array([fem.elem.gbasis(fem.mapping, pts, k, tind=cells)[0] for k in range(fem.Nbfun)])
        sample_idx = np.tile(np.arange(nsamples), fem.Nbfun)
        dof_idx = fem.element_dofs[:, cells].flatten()
        # .todense(): Temporary workaround to the absence of einsum in sparse library
        return sparse.COO((sample_idx, dof_idx), np.ravel(phis), shape=(nsamples, self.n_output_features_)).todense()

    def deriv(self, X, deriv_order=1):
        """
        Evaluate the gradient of every basis function at the points ``X``.

        Parameters
        ----------
        X : array of shape (nsamples, dim)
        deriv_order : int, optional
            Not used by this method.

        Returns
        -------
        dense array of shape (nsamples, n_output_features_, dim)
            Requires ``fit`` to have been called (reads
            ``n_output_features_``).
        """
        fem = self.basis_fem
        nsamples, dim = X.shape
        cells = self.element_finder(X)
        # TODO: find a way to cache this result
        pts = fem.mapping.invF(X.T[:, :, np.newaxis], tind=cells)
        # TODO: check the shape and extract the various dimensions from it
        phis = np.array([fem.elem.gbasis(fem.mapping, pts, k, tind=cells)[0].grad.transpose([1, 2, 0]) for k in range(fem.Nbfun)])
        # NOTE(review): the transpose moves the first axis of ``grad`` last, so
        # np.ravel(phis) runs over (basis function, ..., that axis) while the
        # index arrays below are laid out as (dim, basis function, sample).
        # Both orders coincide when dim == 1 -- confirm for dim > 1.
        sample_idx = np.tile(np.arange(nsamples), dim * fem.Nbfun)
        dof_idx = np.tile(fem.element_dofs[:, cells].flatten(), dim)
        # The last index array must look like 0000 1111 2222
        dim_idx = np.repeat(np.arange(dim), nsamples * fem.Nbfun)
        # .todense(): Temporary workaround to the absence of einsum in sparse library
        return sparse.COO((sample_idx, dof_idx, dim_idx), np.ravel(phis), shape=(nsamples, self.n_output_features_, dim)).todense()

    def hessian(self, X):
        """Not implemented (only for ElementGlobal)."""
        raise NotImplementedError

    def antiderivative(self, X, order=1):
        """Not implemented."""
        raise NotImplementedError
# Draft body left over from the unimplemented ``antiderivative`` method:
#     nsamples, dim = X.shape
#     features = np.zeros((nsamples, self.n_output_features_))
#     return features


# Demo: build P2 line-element features on [0, 5], print the output shapes and
# plot every basis function together with its derivative.
if __name__ == "__main__":  # pragma: no cover
    import matplotlib.pyplot as plt
    import skfem

    m = skfem.MeshLine(np.linspace(0, 5, 4 + 1))
    base_elem = skfem.ElementLineP2()  # skfem.ElementTriRT0()
    # e = skfem.ElementVector(base_elem)
    basis_fem = skfem.CellBasis(m, base_elem)
    x_range = np.linspace(0, 5, 10).reshape(-1, 1)
    # basis = BSplineFeatures(6, k=3)
    basis = FEMScalarFeatures(basis_fem)
    basis.fit(x_range)
    print(x_range.shape)
    print("Basis")
    print(basis.basis(x_range).shape)
    print("Deriv")
    print(basis.deriv(x_range).shape)
    # print("Hessian")
    # print(basis.hessian(x_range).shape)

    # Plot basis
    x_range = np.linspace(0, 5, 150).reshape(-1, 1)
    # basis = basis.fit(x_range)
    # basis = LinearFeatures().fit(x_range)
    # basis() and deriv() already return dense NumPy arrays (they call
    # .todense() internally), so no further conversion is needed here.
    y = basis.basis(x_range)
    z = basis.deriv(x_range)
    plt.grid()
    for n in range(y.shape[1]):
        plt.plot(x_range[:, 0], y[:, n])
        plt.plot(x_range[:, 0], z[:, n, 0])
    plt.show()