# Source code for sigmaepsilon.mesh.plotting.mpl.utils

# -*- coding: utf-8 -*-
from ...config import __hasmatplotlib__

if not __hasmatplotlib__:  # pragma: no cover

    def triplotter(*_, **__):
        """
        Placeholder used when Matplotlib is not available.

        Accepts and ignores any arguments and always raises ImportError
        with installation instructions.
        """
        message = (
            "You need Matplotlib for this. Install it with 'pip install matplotlib'. "
            "You may also need to restart your kernel and reload the package."
        )
        raise ImportError(message)

    def get_fig_axes(*_, **__):
        """
        Placeholder used when Matplotlib is not available.

        Accepts and ignores any arguments and always raises ImportError
        with installation instructions.
        """
        message = (
            "You need Matplotlib for this. Install it with 'pip install matplotlib'. "
            "You may also need to restart your kernel and reload the package."
        )
        raise ImportError(message)

    def decorate_mpl_ax(*_, **__):
        """
        Placeholder used when Matplotlib is not available.

        Accepts and ignores any arguments and always raises ImportError
        with installation instructions.
        """
        message = (
            "You need Matplotlib for this. Install it with 'pip install matplotlib'. "
            "You may also need to restart your kernel and reload the package."
        )
        raise ImportError(message)

else:
    from typing import Iterable, Callable, Any
    from functools import wraps

    import numpy as np
    from numpy import ndarray
    import matplotlib.pyplot as plt
    from matplotlib.figure import Figure, Axes
    from matplotlib.patches import Polygon
    from matplotlib.collections import PatchCollection

    from sigmaepsilon.mesh.triang import triangulate
    from sigmaepsilon.core.typing import issequence

    class TriPatchCollection(PatchCollection):
        """
        A patch collection holding one closed polygon per cell.

        Parameters
        ----------
        cellcoords
            An iterable whose items are the vertex coordinates of one cell
            each; every item is handed to ``matplotlib.patches.Polygon``
            as is.
        *args, **kwargs
            Forwarded unchanged to ``PatchCollection``.
        """

        def __init__(self, cellcoords, *args, **kwargs):
            # Iterate the cells directly instead of indexing through an
            # integer range; the resulting polygons are the same.
            patches = [Polygon(points, closed=True) for points in cellcoords]
            super().__init__(patches, *args, **kwargs)

    def triplotter(plotter: Callable) -> Callable:
        """
        Decorator that prepares the figure, the axes and the data for a
        plotting function that draws on a single Axes.

        The wrapped function is called as
        ``plotter(triobj, ax=..., fig=..., title=..., **kwargs)`` when no data
        is given, and as
        ``plotter(triobj, column, fig=..., ax=..., title=..., label=..., **kwargs)``
        once for every Axes returned by ``get_fig_axes`` otherwise.

        Parameters of the returned wrapper
        ----------------------------------
        triobj
            Either an object the plotter understands, or a tuple of
            ``(coords, topo)``. A tuple is replaced by the last item returned
            by ``triangulate(points=coords[:, :2], triangles=topo)``.
        *args
            Extra positional arguments; only forwarded to ``get_fig_axes``.
        data
            Optional 1d or 2d array. A 2d array is plotted column by column.
        title, label
            A single value, or a sequence (as judged by ``issequence``)
            indexed by the column number. A single value is repeated for
            every column. ``label`` is only forwarded when data is given.
        fig, ax, axes, fig_kw
            Forwarded to ``get_fig_axes``.
        **kwargs
            Forwarded to the plotter.

        Returns
        -------
        The return value of the plotter if no data is given, otherwise a list
        with one return value per plotted column.

        Raises
        ------
        ValueError
            If ``data`` has more than two dimensions.
        """

        @wraps(plotter)
        def inner(
            triobj: Any,
            *args,
            data: ndarray = None,
            title: str = None,
            label: str = None,
            fig: Figure = None,
            ax: Axes = None,
            axes: Iterable[Axes] = None,
            fig_kw: dict = None,
            **kwargs,
        ) -> Any:
            # Validate before touching Matplotlib, so that a bad array does
            # not leave a freshly created, unused figure behind.
            if data is not None and len(data.shape) > 2:
                raise ValueError("Data must be a 1 or 2 dimensional array.")

            fig, axes = get_fig_axes(
                *args, data=data, ax=ax, axes=axes, fig=fig, fig_kw=fig_kw
            )

            if isinstance(triobj, tuple):
                coords, topo = triobj
                # Only the first two coordinate columns are triangulated.
                triobj = triangulate(points=coords[:, :2], triangles=topo)[-1]

            if data is None:
                return plotter(triobj, ax=axes[0], fig=fig, title=title, **kwargs)

            n_columns = 1 if len(data.shape) == 1 else data.shape[1]
            # Work on a 2d view so 1d and 2d input share the same code path;
            # the caller's array is not modified.
            columns = data.reshape((data.shape[0], n_columns))

            if not issequence(title):
                title = n_columns * (title,)

            if not issequence(label):
                label = n_columns * (label,)

            return [
                plotter(
                    triobj,
                    columns[:, i],
                    fig=fig,
                    ax=target,
                    title=title[i],
                    label=label[i],
                    **kwargs,
                )
                for i, target in enumerate(axes)
            ]

        return inner

    def get_fig_axes(
        *args,
        data=None,
        fig=None,
        axes=None,
        shape=None,
        horizontal=False,
        ax=None,
        fig_kw=None,
    ) -> tuple:
        """
        Returns a figure and a collection of axes to draw on.

        The axes are resolved in this order:

        1. Axes passed by keyword (``axes``, or ``ax``; a tuple or list given
           as ``ax`` is treated as ``axes``) are returned as they are. If no
           figure was given, it is looked up on the ``figure`` attribute of
           the first Axes, when that attribute exists.
        2. If a single plot is needed (no data, 1d data, or 2d data with one
           column) and a positional argument is present, the first positional
           argument is used as the Axes and ``fig`` is returned as given.
        3. Otherwise the axes are created: inside ``fig`` if it was given,
           in a new figure (built with ``fig_kw``) if it was not.

        Parameters
        ----------
        *args
            Only the first item is ever used, see point 2 above.
        data
            Optional array; the number of its columns decides how many axes
            are created.
        fig
            An existing figure, optional.
        axes
            Existing axes, optional.
        shape
            Layout of the created grid as ``(rows, columns)`` or as a single
            integer. Only used if several axes have to be created.
        horizontal
            Selects which side of the grid gets the length when the layout
            has to be expanded to two numbers.
        ax
            A single Axes, or a tuple or list of them, optional.
        fig_kw
            Keyword arguments for ``plt.subplots``; only used if a new
            figure is created.

        Returns
        -------
        tuple
            The figure (may be None if it is unknown) and the axes.

        Raises
        ------
        AssertionError
            If ``shape`` does not provide exactly one Axes per data column.
        """
        if isinstance(ax, (tuple, list)):
            axes = ax

        # Axes handed in by the caller must never be dropped in favour of
        # newly created ones, whether or not the figure came with them.
        if axes is not None or ax is not None:
            given = axes if axes is not None else (ax,)
            if fig is None:
                members = np.ravel(np.asarray(given, dtype=object))
                if members.size:
                    fig = getattr(members[0], "figure", None)
            return fig, given

        if data is None or len(data.shape) == 1:
            n_columns = 1
        else:
            n_columns = data.shape[1]

        subplots_kw = {} if fig_kw is None else fig_kw

        if n_columns == 1:
            if args:
                # A positional Axes is passed through together with the
                # figure exactly as it was received.
                return fig, (args[0],)
            if fig is None:
                fig, ax = plt.subplots(**subplots_kw)
            else:
                # Draw on the figure the caller provided; `gca` creates an
                # Axes if the figure does not have one yet.
                ax = fig.gca()
            return fig, (ax,)

        # NOTE(review): horizontal=True yields `n` rows and one column,
        # which Matplotlib stacks vertically; confirm that this is intended.
        if shape is None:
            shape = (n_columns, 1) if horizontal else (1, n_columns)
        else:
            if isinstance(shape, int):
                shape = (shape, 1) if horizontal else (1, shape)
            if n_columns != shape[0] * shape[1]:
                # Raised explicitly so the check also runs under `python -O`;
                # the exception type is the one callers have always seen.
                raise AssertionError("Mismatch in shape and data.")

        if fig is None:
            fig, axes = plt.subplots(*shape, **subplots_kw)
        else:
            axes = fig.subplots(*shape)

        if not isinstance(axes, Iterable):
            axes = (axes,)

        return fig, axes

[docs] def decorate_mpl_ax( *, fig=None, ax=None, aspect="equal", xlim=None, ylim=None, axis="on", offset=0.05, points=None, axfnc: Callable = None, title=None, suptitle=None, label=None, ): """ Decorates an axis using the most often used modifiers. """ assert ax is not None, ( "A matplotlib Axes object must be provided with " "keyword argument 'ax'!" ) if axfnc is not None: try: axfnc(ax) except Exception: raise RuntimeError("Something went wrong when calling axfnc.") if xlim is None: if points is not None: xlim = points[:, 0].min(), points[:, 0].max() if offset is not None: dx = np.abs(xlim[1] - xlim[0]) xlim = xlim[0] - offset * dx, xlim[1] + offset * dx if ylim is None: if points is not None: ylim = points[:, 1].min(), points[:, 1].max() if offset is not None: dx = np.abs(ylim[1] - ylim[0]) ylim = ylim[0] - offset * dx, ylim[1] + offset * dx ax.set_aspect(aspect) ax.axis(axis) if xlim is not None: ax.set_xlim(*xlim) if ylim is not None: ax.set_ylim(*ylim) if title is not None: ax.set_title(title) if label is not None: ax.set_xlabel(label) if fig is not None and suptitle is not None: fig.suptitle(suptitle) return ax