# Source code for sattoolbox.plots.plots

"""
This module contains the Plotter class, which is designed for creating and managing plots based on
data from pandas DataFrames.
It provides functionalities to:
    - initialize with optional settings (a DataFrame of physical quantities used to infer axis labels, default plot settings and a default style)
    - plot data from the DataFrame (different types of plots)
    - store and manage the plots
"""
from math import sqrt
from typing import Optional, Union, List
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib

from sattoolbox.plots import cyclers

class Plotter:
    """
    A class for convenient plotting from pandas DataFrames with matplotlib.

    The plot(kind='...') method allows for plotting any kind of plot supported by
    pandas.DataFrame.plot() and additional custom kinds (currently available: 'multi_scatter',
    'R2_scatter', 'R2_hexbin'). The method takes care of setting up the figure, axes, and styles,
    and allows for customization of the plot through keyword arguments.
    The plot_subplots method allows to plot multiple subplots in a single figure.

    SAT custom styles are available for different use cases (screen, presentation, paper).
    The styles all use the sat_base style and extend or override it. The base style mostly
    defines the default color cycle, and the setup of the axis and ticks.

    Presentation Styles: (Optimized size and fontsize (18) for presentations)
        - 'presentation': style for plots filling the whole width of the presentation slide.
        - 'presentation_half': style for plots filling half the width (2 plots or plot plus text)

    Paper style sheets:
        - Reduced font size: 10
        - Linewidth, Markersize and ticks-size reduced
        - Various Paper Styles:
            - Sizes: regular (1.5 col --> 5.5 inch), half --> 5.5 inch, full --> 7.4 inch
              (see Elsevier author guidelines)
            - Black&White Paper (bw_paper): Color (3 shades of grey/black) and Linestyle Cycle
              changed, Grid turned off

    Screen Style:
        - Size fitted to be almost fullscreen on a Full HD Screen
        - Facecolor: White
        - Fontsize: 20
        - Both minor and Major Grid activated

    Available styles are:
        - paper
        - paper_half
        - paper_full
        - bw_paper
        - bw_paper_half
        - bw_paper_full
        - screen
        - presentation
        - presentation_half

    Attributes:
        physical_quantities (pd.DataFrame): DataFrame containing information about physical
            quantities.
        plot_kinds_pd (list): List of valid plot kinds from pandas.
        plot_kinds_custom (list): List of custom plot kinds defined in this class.
        kind_specifics (dict): Dictionary containing specifics for each plot kind.
        sat_styles (list): List of available SAT styles.
        style_default (str or list of str): Default style for the plots, already resolved by
            _style_check_n_correct (SAT styles become a list [sat_base, <style>]).
    """
    # Plot kind specifics:
    #   - set special defaults
    #   - Check if the combination of kind and xcols / ycols makes sense
    #   - adjust the "kind_to_plot" for custom kinds (e.g. multi_scatter uses line plot)
    kind_specifics = {
        'line': {'uses_cycler': True},
        'bar': {},
        'barh': {},
        'hist': {
            'xcols_check': {'type': type(None),
                            'message': 'xcols must be None for histograms (do not specify)'},
        },
        'box': {
            'defaults': {'legend': False},
            'xcols_check': {'type': type(None),
                            'message': 'xcols must be None for box plot (do not specify)'},
        },
        'kde': {'uses_cycler': True},
        'density': {'uses_cycler': True},
        'area': {
            'uses_cycler': True,
            'xcols_check': {'type': (type(None), str),
                            'message': 'xcols must be None or a string for area plot.'},
        },
        'pie': {
            'xcols_check': {'type': type(None),
                            'message': 'xcols must be None for pie plot (do not specify)'},
            'ycols_check': {'type': str, 'message': 'ycols must be a string for pie plot.'},
        },
        'scatter': {
            'defaults': {'legend': False},
            'xcols_check': {'type': str,
                            'message': ('xcols must be a string for scatter plot. To plot ' +
                                        "multiple ycols over multiple xcols use kind='multi_scatter'")},
            'ycols_check': {'type': str,
                            'message': ('ycols must be a string for scatter plot. To plot ' +
                                        "multiple ycols use kind='multi_scatter'")},
        },
        'hexbin': {
            'defaults': {'grid': False, 'legend': False, 'mincnt': np.nextafter(0, 1),
                         'gridsize': 50},
            'xcols_check': {'type': str, 'message': 'xcols must be a string for hexbin plot.'},
            'ycols_check': {'type': str, 'message': 'ycols must be a string for hexbin plot.'},
        },
        'multi_scatter': {
            'defaults': {'linestyle': '', 'marker': 'o'},
            'kind_to_plot': 'line',
            'xcols_check': {'type': (str, list, tuple),
                            'message': ('xcols must be a string or a list of strings for ' +
                                        "multi_scatter plot.")},
            'ycols_check': {'type': (str, list, tuple),
                            'message': ('ycols must be a string or a list of strings for ' +
                                        "multi_scatter plot.")},
        },
        # R2 plots compare exactly one modelled column with one measured column; there is no
        # multi-column R2 kind, so the messages must not point to one.
        'R2_scatter': {
            'defaults': {'legend': False},
            'xcols_check': {'type': str,
                            'message': ('xcols must be a string for R2_scatter plot. R2 plots ' +
                                        'support only a single pair of xcols and ycols.')},
            'ycols_check': {'type': str,
                            'message': ('ycols must be a string for R2_scatter plot. R2 plots ' +
                                        'support only a single pair of xcols and ycols.')},
        },
        'R2_hexbin': {
            'defaults': {'grid': False, 'legend': False, 'mincnt': np.nextafter(0, 1),
                         'gridsize': 50},
            'xcols_check': {'type': str, 'message': 'xcols must be a string for R2_hexbin plot.'},
            'ycols_check': {'type': str, 'message': 'ycols must be a string for R2_hexbin plot.'},
        },
    }

    # available SAT styles
    # defined as class attributes, so they can be accessed directly from outside
    sat_styles = ['screen', 'presentation', 'presentation_half',
                  'paper', 'paper_half', 'paper_full',
                  'bw_paper', 'bw_paper_half', 'bw_paper_full']

    def __init__(self,
                 physical_quantities: Optional[pd.DataFrame] = None,
                 plot_settings: Optional[dict] = None,
                 style_default: Optional[str] = 'screen'
                 ) -> None:
        """
        Initialize the Plotter object.

        Parameters
        ----------
        physical_quantities : Optional[pd.DataFrame], optional
            Optional pandas DataFrame containing information about physical quantities, by
            default None. The DataFrame must have the rows (index entries) 'label', 'color',
            'symbol' and 'unit'. If not provided, a default DataFrame will be used.
        plot_settings : Optional[dict], optional
            Optional dictionary containing default plot settings, by default None. Whatever is
            not provided, will be taken from a default dictionary.
            Note: These are settings that are NOT part of plt.rcParams (those are controlled via
            the style parameters or set explicitly in the plot method).
        style_default : Optional[str], optional
            Style used whenever plot() / plot_subplots() are called without an explicit style,
            by default 'screen'. Available styles are the SAT styles (see Plotter.sat_styles),
            any style from plt.style.available, or a custom '.mplstyle' file (in the working dir
            or given as full path). None is treated like the default 'screen'.

        Raises
        ------
        ValueError
            If style_default cannot be resolved, or if a custom plot kind collides with a
            pandas plot kind.

        Returns
        -------
        None
        """
        self._set_physical_quantities(physical_quantities)
        self._set_default_plot_settings(plot_settings)
        # _style_check_n_correct(None) falls back to self.style_default, which does not exist yet
        # at this point -> resolve None to the documented default first.
        if style_default is None:
            style_default = 'screen'
        self.style_default = self._style_check_n_correct(style_default)
        self._last_context = None  # to keep track of the last style context used
        self.plot_kinds_pd = _get_valid_plot_kinds_pandas()
        # custom plot types that are not part of pandas
        self.plot_kinds_custom = ['multi_scatter', 'R2_scatter', 'R2_hexbin']

        # Custom kinds must not shadow pandas kinds (checked before touching class-level state)
        for custom_kind in self.plot_kinds_custom:
            if custom_kind in self.plot_kinds_pd:
                raise ValueError(f"Invalid custom plot kind '{custom_kind}'. " +
                                 f"Must NOT be one of the pandas kinds: {self.plot_kinds_pd}.")

        #TODO Put this in a better place. This is a strange place to change the kind specifics dict
        # NOTE: kind_specifics is a class attribute, so the bound methods stored here are shared
        # by all instances (the most recently created Plotter wins).
        Plotter.kind_specifics['multi_scatter']['plot_func'] = self._plot_multi_scatter
        Plotter.kind_specifics['R2_scatter']['plot_func'] = self._plot_R2_scatter
        Plotter.kind_specifics['R2_hexbin']['plot_func'] = self._plot_R2_hexbin

    ### PUBLIC METHODS ###
    # Note towards "typing": Union[] might be deprecated, at least as of python 3.9 it is possible
    # to write str | list[str]
[docs] def plot(self, df: pd.DataFrame, kind: Optional[str] = 'line', xcols: Optional[Union[str, List[str]]] = None, xlabel: Optional[str] = None, ycols: Optional[Union[str, List[str]]] = None, ylabel: Optional[str] = None, ax : Optional[plt.Axes] = None, fig : Optional[plt.Figure] = None, style : Optional[str] = None, **kwargs ) -> tuple[plt.Figure, Union[plt.Axes, tuple[plt.Axes]]]: """ Plot from a DataFrame using pandas DataFrame.plot(). This method adds some convenience features to the pandas plot method, such as handling of styles, titles, and legends. It is a general method for plotting and can be used for any kind of plot supported by pandas (see Plotter.plot_kinds_pd) plus some additional plot types (see Plotter.plot_kinds_custom). Parameters ---------- df : pd.DataFrame DataFrame containing the data to plot. kind : str, optional Type of plot to create. Check Plotter.plot_kinds_pd and Plotter.plot_kinds_custom for available plot types. Default is 'line'. xcols : str or list of str, optional List of the column(s) to use for the x-axis. If not provided, the index of the DataFrame will be used as x-values. If multiple ycols are provided, x_col must either be a list of the same lengths or a single column name that will be used for all ycols. xlabel : str, optional Label for the x-axis. Default is None. If None, the label will be inferred from the xcols name(s). ycols : str or list of str, optional Name of the column(s) to use for the y-axis. If not provided, all columns from the DataFrame will be plotted. ylabel : str, optional Label for the y-axis. Default is None. If None, the label will be inferred from the ycols name(s). ax : matplotlib.axes.Axes, optional The axes on which to plot. If not provided, a new figure and axes will be created. fig : matplotlib.figure.Figure, optional The figure on which to plot. If not provided, a new figure will be created. style : str, optional Style to use for the plot. 
If not provided, the default style of the Plotter instance will be used. **kwargs : dict Additional keyword arguments to pass to the plot. All keyword arguments accepted by pandas.DataFrame.plot are accepted. Returns ------- fig : matplotlib.figure.Figure The figure on which has been plotted. ax : matplotlib.axes.Axes or tuple of matplotlib.axes.Axes The axes (or tuple of axes, if secondary axis is used) on which has been plotted. """ # Translate the old names for y_label, x_label, ... to the new ones name_updates = {'y_label':'ylabel', 'x_label':'xlabel', 'y_cols':'ycols', 'x_cols':'xcols'} for old_name in name_updates: for key in kwargs.copy(): if old_name in key: # check if the key contains the old name new_key = key.replace(old_name, name_updates[old_name]) if new_key not in kwargs: warnings.warn(f"'{key}' is deprecated. Use '{new_key}' instead.") if new_key == 'ylabel': ylabel = kwargs.pop(key) elif new_key == 'xlabel': xlabel = kwargs.pop(key) elif new_key == 'ycols': ycols = kwargs.pop(key) elif new_key == 'xcols': xcols = kwargs.pop(key) else: kwargs[new_key] = kwargs.pop(key) else: raise ValueError(f"'{key}' and '{new_key}' are both provided."+ f"Please provide only '{new_key}'.") # check if "clip_figure" is in kwargs and process it (only used to supress clipping when # plotting subplots with plot_subplots()) clip_figure = kwargs.pop("clip_figure", True) # check plot kind if kind not in self.plot_kinds_custom + self.plot_kinds_pd: raise ValueError(f"Invalid plot kind '{kind}'. 
Must be one of {self.plot_kinds_pd} " + f"(plot kinds from pandas) or {self.plot_kinds_custom}") kind_to_plot = kind # apply kind_specific settings (from Class attribute kind_specifics) if kind in Plotter.kind_specifics: # Settings that are defaulted to something if 'defaults' in Plotter.kind_specifics[kind]: for key, value in Plotter.kind_specifics[kind]['defaults'].items(): if key not in kwargs: kwargs[key] = value # change kind_to_plot for custom kinds kind_to_plot = Plotter.kind_specifics[kind].get('kind_to_plot', kind) # check xcols type if 'xcols_check' in Plotter.kind_specifics[kind]: if not isinstance(xcols, Plotter.kind_specifics[kind]['xcols_check']['type']): raise ValueError(Plotter.kind_specifics[kind]['xcols_check']['message']) # check ycols type if 'ycols_check' in Plotter.kind_specifics[kind]: if not isinstance(ycols, Plotter.kind_specifics[kind]['ycols_check']['type']): raise ValueError(Plotter.kind_specifics[kind]['ycols_check']['message']) ycols = self._cols_check_n_correct(df, ycols) ylabel = self._infer_label(ycols, ylabel, axis_context='the y-axis') # keep provided xlabel (needed for check of xlabel with sec axis) xlabel_provided = xlabel # handling xcols if xcols is not None: xcols = self._cols_check_n_correct(df, xcols, required_lengths=[1, len(ycols)]) xlabel = self._infer_label(xcols, xlabel) if isinstance(xcols, (list, tuple)) and len(xcols) == 1: xcols = xcols[0] else: # xcols were not passed --> the index of the dataframe will be used as x-values # Nothing to be done except getting the name xlabel = self._infer_label([df.index.name] if df.index.name is not None else ['x'], xlabel, axis_context='the x-axis') # TODO (Idea): Infer colors for the lines from the physical_quantities DataFrame # Check if secondary y-axis is used and extract the settings for it sec_settings = self._get_secondary_ax_settings(kwargs, required_kwargs='ycols') if sec_settings: sec_ycols = self._cols_check_n_correct(df, sec_settings.pop('ycols')) sec_ylabel = 
self._infer_label(sec_ycols, sec_settings.pop('ylabel', None), axis_context='the secondary y-axis') # Handle xcols for sec axis if 'xcols' in sec_settings: sec_xcols = sec_settings.pop('xcols') else:# Default will be the same as for primary axis sec_xcols = xcols if sec_xcols is not None: sec_xcols = self._cols_check_n_correct(df, sec_xcols, required_lengths=[1, len(sec_ycols)]) sec_xlabel = sec_settings.pop('xlabel', xlabel_provided) if sec_xlabel is None: # This means there was no xlabel provided for either primary or sec axis # Only then infer a label sec_xlabel = self._infer_label(sec_xcols, None, axis_context='the secondary x-axis') # else: # sec_xlabel = sec_settings.pop('xlabel', xlabel) if len(sec_xcols)==1: sec_xcols = sec_xcols[0] else: # neither xcols nor sec_xcols were passed --> df.index will be used as x-values # check for sec_xlabel in sec_settings, otherwise use the primary xlabel sec_xlabel = sec_settings.pop('xlabel', xlabel) # Warn if sec_xlabel is inconsistent if sec_xlabel != xlabel:# and not xlabel_provided: warnings.warn((f"xlabel '{xlabel}' is inconsistent with sec_xlabel " + f"'{sec_xlabel}'. Setting xlabel = None.")) xlabel = None # take all kwargs that are in the form of rcParams and parse them to the correct form for # rcParams update_rcparams, parsed_rcparams = self._get_rcparams_from_kwargs(kwargs) # handle plot style style = self._style_check_n_correct(style) # Take title from kwargs and call _infer_title title = self._infer_title(kind.capitalize() + ' plot', ycols, style, kwargs.pop('title', 0)) # all steps that use rcParams (esp. 
all the plotting) should be done in this style context with plt.style.context(style): # override style with the rcParams settings provided in kwargs with plt.rc_context(update_rcparams): self._last_context = plt.rcParams.copy() specific_kwargs = {key: value for key, value in kwargs.items() if key not in update_rcparams and key not in parsed_rcparams} setup_settings, plot_settings, other_kwargs = self._get_setup_and_plot_settings( specific_kwargs) legend_settings = _extract_kwargs_with_prefix(other_kwargs, ['legend', 'leg']) label_settings = _extract_kwargs_with_prefix(other_kwargs, ['label', 'lab']) xlabel_settings = _extract_kwargs_with_prefix(other_kwargs, ['xlabel']) ylabel_settings = _extract_kwargs_with_prefix(other_kwargs, ['ylabel']) # setup figure and axes fig, ax = self._setup_plot(title=title, fig=fig, ax=ax, **setup_settings) # Create and set the cycler for cycable properties (e.g. color, linestyle) if Plotter.kind_specifics[kind_to_plot].get('uses_cycler', False): cycle_dict = plt.rcParams['axes.prop_cycle'].by_key() cycler, cycle_dict, other_kwargs = cyclers.create_cycle(cycle_dict, **other_kwargs) ax.set_prop_cycle(cycler) # actual plotting if len(ycols)==1: # if only one y_col is provided, it is passed as a string ycols = ycols[0] plot_func = Plotter.kind_specifics[kind].get('plot_func') if plot_func is not None: plot_func(df, xcols=xcols, ycols=ycols, ax=ax, **plot_settings, **other_kwargs) else: df.plot(x=xcols, y=ycols, kind=kind_to_plot, ax=ax, **plot_settings, **other_kwargs) # set labels self._setup_labels(ax=ax, xlabel=xlabel, ylabel=ylabel, xlabel_settings=xlabel_settings, ylabel_settings=ylabel_settings, **label_settings) # Plot the secondary y-axis if settings are provided if sec_settings: # Assume that all settings passed for the 'normal' axis are also valid for the # secondary axis --> copy them and update with sec_settings update_rcparams, parsed_rcparams = self._get_rcparams_from_kwargs(sec_settings) with 
plt.rc_context(update_rcparams): sec_setup_settings, sec_plot_settings, sec_other_kwargs = ( self._get_setup_and_plot_settings(sec_settings)) sec_setup_settings = {**setup_settings, **sec_setup_settings} sec_plot_settings = {**plot_settings, **sec_plot_settings} sec_other_kwargs = {**other_kwargs, **sec_other_kwargs} # Remove settings, that should not be repeated for the sec axis # (legend, label) _extract_kwargs_with_prefix(sec_other_kwargs, ['legend', 'leg']) sec_label_settings = _extract_kwargs_with_prefix(sec_other_kwargs, ['label', 'lab']) sec_xlabel_settings = _extract_kwargs_with_prefix(sec_other_kwargs, ['xlabel']) sec_ylabel_settings = _extract_kwargs_with_prefix(sec_other_kwargs, ['ylabel']) # setup figure and axes ax = (ax, ax.twinx()) # create a secondary y-axis self._setup_plot(title=None, fig=fig, ax=ax[1], **sec_setup_settings) # Create and set the prop_cycler if Plotter.kind_specifics[kind_to_plot].get('uses_cycler', False): n_lines = len(ycols) if not isinstance(ycols, str) else 1 cycler, cycle_dict, other_kwargs = cyclers.create_cycle(cycle_dict, n_lines=n_lines, **sec_other_kwargs) ax[1].set_prop_cycle(cycler) # actual plotting on the secondary y-axis if plot_func is not None: plot_func(df, xcols=sec_xcols, ycols=sec_ycols, ax=ax[1], **sec_plot_settings, **sec_other_kwargs) else: df.plot(x=sec_xcols, y=sec_ycols, kind=kind_to_plot, ax=ax[1], **sec_plot_settings, **sec_other_kwargs) # setup the labels for the secondary axis self._setup_labels(ax=ax[1], xlabel=sec_xlabel, ylabel=sec_ylabel, is_sec_axis=True, xlabel_settings=sec_xlabel_settings, ylabel_settings=sec_ylabel_settings, **dict(label_settings, **sec_label_settings)) # Post-Processing section if plot_settings.get('legend', True): create_legend(ax, **legend_settings) # clip figure if no explicit figsize is given and if ax has a fixed aspect ratio if clip_figure and not 'figsize' in kwargs.keys(): # handle primary and potential secondary axis if isinstance(ax, plt.Axes): has_fixed_aspect = 
ax.get_aspect() != 'auto' else: has_fixed_aspect = any(a.get_aspect() != 'auto' for a in ax) if has_fixed_aspect: _clip_fig(fig) return fig, ax
def _plot_multi_scatter(self, df: pd.DataFrame, xcols, ycols, ax, **kwargs ): """ Custom plot function for multi_scatter plots (that is scatter plots with more than one pair of x and y involved). This function is called from plot() if the kind is 'multi_scatter'. Parameters ---------- df : pd.DataFrame DataFrame containing the data to plot. xcols : str or list of str Name of the column(s) to use for the x-axis. ycols : str or list of str Name of the column(s) to use for the y-axis. ax : matplotlib.axes.Axes The axes on which to plot. **kwargs : dict Additional keyword arguments to pass to the plot (e.g., linestyle, linewidth, etc.). """ # Wrap the xcols and ycols in lists if they are not already if isinstance(xcols, str): xcols = [xcols] if isinstance(ycols, str): ycols = [ycols] # Duplicate xcols for each y_col if only one x_col is provided # (This needs xcols and ycols to be in lists) xcols = xcols*len(ycols) if len(xcols) == 1 else xcols # Construct a new DataFrame with the xcols as index and ycols as values # This is done to allow for easy plotting of multiple ycols against the same x_col(s). df_new = pd.concat([df.set_index(x_i)[y_i] for x_i, y_i in list(zip(xcols, ycols))], axis=1) # Keep xlabel (was set in setup_plot() but would be changed here if not given explicitely) # xlabel = ax.get_xlabel() # df_new.plot(y=ycols, kind='line', ax=ax, xlabel = xlabel, **kwargs) df_new.plot(y=ycols, kind='line', ax=ax, **kwargs) def _plot_R2_scatter(self, df: pd.DataFrame, xcols: str, ycols: str, ax: plt.Axes, metrics: Optional[Union[str, List[str]]] = 'all', **kwargs): """ Custom plot function for R2_scatter plots for comparison of modelled and measured values. The plot is a scatter plot with a squared layout, an angle bisector line and the metrics R2, MAE and RMSE are printed into the plot area. This function is called from plot() if the kind is 'R2_scatter'. Parameters ---------- df : pd.DataFrame DataFrame containing the data to plot. 
xcols : str Name of the column to use for the x-axis (modelled values). ycols : str Name of the column to use for the y-axis (measured values). ax : matplotlib.axes.Axes The axis on which to plot. metrics : str or list of str, optional Metrics to display on the plot. Can be 'all', 'R2', 'MAE', 'RMSE' or a list of these. To suppress the metrics, set to None or False. **kwargs : dict Additional keyword arguments to pass to the plot (e.g., markerstyle etc). """ ax.set_aspect('equal', adjustable='box') # set the size of the ax to be square self._add_angle_bisector_line(x=df.loc[:,xcols], y=df.loc[:,ycols], ax=ax) if metrics: self._add_metrics_to_R2_plot(x=df.loc[:,xcols], y=df.loc[:,ycols], ax=ax, metrics=metrics) # plot the data df.plot(x=xcols, y=ycols, kind='scatter', ax=ax, **kwargs) def _plot_R2_hexbin(self, df: pd.DataFrame, xcols: str, ycols: str, ax: plt.Axes, metrics: Optional[Union[str, List[str]]] = 'all', **kwargs): """ Custom plot function for R2_hexbin plots for comparison of modelled and measured values. The plot is a hexbin plot with a squared layout, an angle bisector line and the metrics R2, MAE and RMSE are printed into the plot area. This function is called from plot() if the kind is 'R2_hexbin'. Parameters ---------- df : pd.DataFrame DataFrame containing the data to plot. xcols : str Name of the column to use for the x-axis (modelled values). ycols : str Name of the column to use for the y-axis (measured values). ax : matplotlib.axes.Axes The axis on which to plot. metrics : str or list of str, optional Metrics to display on the plot. Can be 'all', 'R2', 'MAE', 'RMSE' or a list of these. To suppress the metrics, set to None or False. **kwargs : dict Additional keyword arguments to pass to the plot (e.g., markerstyle etc). 
""" # set the size of the ax to be square ax.set_aspect('equal', adjustable='box') self._add_angle_bisector_line(x=df.loc[:,xcols], y=df.loc[:,ycols], ax=ax) if metrics: self._add_metrics_to_R2_plot(x=df.loc[:,xcols], y=df.loc[:,ycols], ax=ax, metrics=metrics) # plot the data df.plot(x=xcols, y=ycols, kind='hexbin', ax=ax, **kwargs) def _add_angle_bisector_line(self, x, y, ax): """ Add an angle bisector line to the plot. Parameters ---------- x : array-like x-values of the data. y : array-like y-values of the data. ax : matplotlib.axes.Axes The axis on which to plot. """ min_val = min(x.min(), y.min()) max_val = max(x.max(), y.max()) ax.plot([min_val, max_val], [min_val, max_val], color='black', linestyle='--', linewidth=1, label='Angle bisector') def _add_metrics_to_R2_plot(self, x, y, ax, metrics = 'all'): """ Add metrics to the R2 plot. Parameters ---------- x : array-like x-values of the data. y : array-like y-values of the data. ax : matplotlib.axes.Axes The axis on which to plot. metrics : str or list of str, optional Metrics to display on the plot. Can be 'all', 'R2', 'MAE', 'RMSE' or a list of these. Default is 'all'. """ try: from sklearn.metrics import r2_score, mean_absolute_error, root_mean_squared_error except ImportError: warnings.warn("scikit-learn is required to print metrics on R2 plots. "+ "Please install it via 'pip install scikit-learn' or 'conda install "+ "scikit-learn'. Metrics are not printed now.") return # calculate metrics (R2, MAE, RMSE) dict_metrics = { 'R2': r2_score(x, y), 'MAE': mean_absolute_error(x, y), 'RMSE': root_mean_squared_error(x, y) } if metrics == 'all': metrics = dict_metrics.keys() if isinstance(metrics, str): metrics = [metrics] metrics_text = '' for metric in metrics: if metric not in dict_metrics: raise ValueError(f"Invalid metric '{metric}'. 
Must be one of {dict_metrics.keys()}") metric_value = dict_metrics[metric] metrics_text += f"{metric}: {metric_value:.2f}\n" if metric == 'R2' and metric_value < 0: message = ( f"R² = {metric_value:.2f} < 0. This indicates that the model is worse than a " "horizontal line. For more information check the documentation of the r2_score " "function at: " "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html" ) warnings.warn(message) # add metrics to the plot fontsize = plt.rcParams.get('xtick.labelsize') ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes, fontsize=fontsize, va='top')
[docs] def plot_subplots(self, df: pd.DataFrame, plots: dict, layout: Optional[tuple] = None, style: Optional[str] = None, figsize: Optional[tuple] = None, hspace: Optional[float] = 0.05, wspace: Optional[float] = 0.05, **kwargs: Optional[dict] ) -> tuple[plt.Figure, np.ndarray]: """ Plot multiple subplots in a single figure. Parameters ---------- df : pd.DataFrame DataFrame containing the data to plot. plots : dict Dictionary where keys are subplot titles (showing the title can be supressed for each subplot with "title = None" within respective entry in the plots dict) and values are dicts containing keyword arguments for the plot (including plot kind). More information on available plot kinds and valid arguments can be obtained from the plot() method. layout : tuple, optional Tuple containing the layout of the subplots (number of rows, number of columns). If not provided, the layout will be inferred from the number of plots. style : str, optional Style to use for the plot. If not provided, the default style of the Plotter instance will be used. figsize : tuple, optional Tuple containing the size of the figure (width, height) in inches. If not provided, the default size from the used style will be used. hspace : float, optional Height space between the subplots as a fraction of size of the subplot group as a whole. If there are more than two rows, the hspace is shared between them. Default is 0.05. More information on padding: https://matplotlib.org/stable/users/explain/axes/constrainedlayout_guide.html#padding-and-spacing wspace : float, optional Width space between the subplots as a fraction of size of the subplot group as a whole. If there are more than two columns, the wspace is shared between them. Default is 0.05. More information on padding: https://matplotlib.org/stable/users/explain/axes/constrainedlayout_guide.html#padding-and-spacing **kwargs : optional Additional keyword arguments to pass to the subplot creation of matplotlib.pyplot.subplots(). 
For more information see: https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.subplots.html Returns ------- fig : matplotlib.figure.Figure The figure containing the subplots. axs : numpy.ndarray of matplotlib.axes.Axes Array of axes objects representing the subplots. """ #TODO: It should be possible to provide a plot type once for all subplots #TODO: Allow to set rcParams other than via style for the entire figure plots_dict = deepcopy(plots) # original plots dict shall not be changed by this method # infer layout from number of plots if layout not provided explicitly if layout is None: num_plots = len(plots_dict) layout = (num_plots, 1) if num_plots <= 3 else (2, num_plots // 2 + num_plots % 2) # check if layout makes sense if layout[0]*layout[1] < len(plots_dict): raise ValueError(f"Layout {layout} is too small for the number of plots {len(plots_dict)}.") if layout[0]*layout[1] > len(plots_dict): warnings.warn(f"Layout {layout} is too big for the number of plots {len(plots_dict)}.") # handle plot style style = self._style_check_n_correct(style) # all steps that work on the plot should be done in this style context with plt.style.context(style): # setup figure and axes fig, axs = plt.subplots(layout[0], layout[1], **kwargs) # set figsize if given explicitely if figsize: fig.set_size_inches(figsize) # iterate over the plots to create the subplots for ax, (title, plot_kwargs) in zip(axs.flatten(), plots_dict.items()): title = plot_kwargs.pop('title', title) self.plot(df=df, fig=fig, ax=ax, clip_figure=False, title=title, **plot_kwargs) # Adjust the layout to ensure appropriate spacing between the subplots fig.get_layout_engine().set(hspace=hspace, wspace=wspace) # clip figure if no explicit figsize is given and if ax has a fixed aspect if figsize is None: has_fixed_aspect = any(a.get_aspect() != 'auto' for a in axs.flatten()) if has_fixed_aspect: _clip_fig(fig) return fig, axs
### PRIVATE METHODS ### def _get_rcparams_from_kwargs(self, kwargs): """ Extract rcParams from kwargs and return them in the correct form for rcParams. For convenience, aliases are provided for some parameters. Currently available: - fontsize: font.size (Note: fontsize sets the 'medium' fontsize, not all) - grid: axes.grid (Note: grid is a boolean, not a string) Parameters ---------- kwargs : dict Dictionary containing keyword arguments. Returns ------- dict Dictionary containing rcParams settings. list List of rcParams settings that were parsed from kwargs. """ # parameter names that shall be known in the plotter and be parsed to rcParams parse_rcparams = {'fontsize' : 'font.size', 'grid' : 'axes.grid'} # extract rcParams from kwargs and set them in the correct form for rcParams update_rcparams = {value: kwargs[key] for key, value in parse_rcparams.items() if key in kwargs and key in parse_rcparams} parsed_rcparams = [key for key in parse_rcparams if key in kwargs] # extract all other rcParams from kwargs, that are already in the form of rcParams update_rcparams.update({key: value for key, value in kwargs.items() if key in plt.rcParams}) return update_rcparams, parsed_rcparams def _style_check_n_correct(self, style: Optional[str]) -> str: """ Handle the style for the plot. Parameters ---------- style : str, optional Style(s) to use for the plot(s). If None, the default style will be used. Returns ------- str The style to use for the plot. """ # check if style exists and set it if style is None: style = self.style_default elif style in self.sat_styles: path = 'sattoolbox.plots.plot_styles.' style = [path+'sat_base', path+style]#'sattoolbox.plots.plot_styles.'+style elif style in plt.style.available: pass else: if not style.endswith('.mplstyle'): style += '.mplstyle' try: with plt.style.context(style): pass except OSError as exc: raise ValueError(f"style must be one of {self.sat_styles}, or any of "+ f"plt.style.available. You provided '{style}'. 
Custom style sheets"+ " must be in the working dir or the full path must be passed." ) from exc return style def _get_setup_and_plot_settings(self, kwargs): """ Split the kwargs into setup and plot settings. Parameters ---------- kwargs : dict Dictionary containing keyword arguments. Returns ------- setup_settings : dict Dictionary containing keyword arguments for setting up the plot. plot_settings : dict Dictionary containing keyword arguments for plotting the data. other_kwargs : dict Dictionary containing keyword arguments that are not used for setting up the plot or plotting the data. """ # only the following arguments will be returned as setup settings or plot settings setup_settings_names = ['figsize', 'layout'] plot_settings_names = ['legend'] # get all settings: provided settings extend or overwrite default settings all_settings = self.default_plot_settings.copy() all_settings.update(kwargs) # some rcParams are ignored by pandas. Thus, they need to be extracted from plt.rcParams # and will be set explicitely rcparams_ignored_by_pandas = {} # key = plot setting, value = rcParam for key, value in rcparams_ignored_by_pandas.items(): if key not in all_settings: all_settings[key] = plt.rcParams[value] setup_settings = {key: all_settings[key] for key in setup_settings_names if key in all_settings} plot_settings = {key: all_settings[key] for key in plot_settings_names if key in all_settings} other_kwargs = {key: value for key, value in all_settings.items() if key not in setup_settings.keys() and key not in plot_settings.keys()} return setup_settings, plot_settings, other_kwargs def _get_secondary_ax_settings(self, kwargs: dict, required_kwargs: Optional[List] = None) -> dict: """ Find the settings for the secondary y-axis in the kwargs. All settings that start with 'sec_' or 'secondary_' are considered settings for the secondary y-axis. 
They are removed from the passed kwargs dictionary and returned as a separate dicitonary where the prefix is removed from the keys. If required kwargs are not provided, the secondary axis will not be created and a warning will be raised. Parameters ---------- kwargs : dict Dictionary containing keyword arguments. required_kwargs : list, optional List of required keyword arguments for the secondary y-axis. This might be dependent on the calling method. A Timeseries plot might require 'sec_ycols', while a Scatter plot might also require 'sec_xcols'. Returns ------- dict Dictionary containing the settings for the secondary y-axis. """ #Settings that are defaulted to something else for the secondary axis than for the primary sec_default_settings = {'grid': False} if isinstance(required_kwargs, str): required_kwargs = [required_kwargs] # A secondary y-axis is used if 'sec_ycols' is provided in kwargs # kwargs that start with 'sec_' or 'secondary_' are considered settings for the sec y-axis # They are removed from the original kwargs and returned as sec_settings # The prefix 'sec_' or 'secondary_' is removed from the keys sec_settings = {key.split('_', maxsplit=1)[1]: kwargs.pop(key) for key in kwargs.copy().keys() if key.startswith(('sec_','secondary_'))} if len(sec_settings)>0: missing_keys = [key for key in required_kwargs if key not in sec_settings] # set default settings for the sec axis (only if they are not explicitely provided) for key, value in sec_default_settings.items(): if key not in sec_settings: sec_settings[key] = value if len(missing_keys)>0: warnings.warn('You provided settings for a secondary y-axis, but did not provide '+ f'the following required kwargs: {required_kwargs}. 
'+ "The secondary y-axis will not be created.") sec_settings = {} # reset sec_settings if required ones were not provided return sec_settings def _set_physical_quantities(self, physical_quantities: Optional[pd.DataFrame] = None): """ Set the physical quantities that shall be known to the Plotter. This is used for inferring labels for the axes. Parameters ---------- physical_quantities : pd.DataFrame, optional DataFrame containing information about physical quantities. The DataFrame must have columns: 'label', 'color', 'symbol', 'unit'. If not provided, a default DataFrame will be used. Raises ------ ValueError If the physical_quantities DataFrame does not have the required columns. Returns ------- None """ #TODO: it could be nicer to load this from a file if physical_quantities is None: # default physical_quantities DataFrame: # label = Name + symobl + "in" + unit # color = default color plots, should be unique for each quantity # symbol = physical symbol for the quantity # unit = unit of the quantity (SI or other standard unit) self.physical_quantities = pd.DataFrame(data={ 'T' : ['Temperature T in °C', 'red', r'$T_{SL}$', '°C'], 'dT' : [r'Temperature difference $\Delta$T in K', 'red', r'$\DeltaT$', 'K'], 'p' : ['Pressure p in bar', 'grey', '$p$', 'bar'], 'dp' : [r'Differential pressure $\Delta$p [bar]', 'lightgrey', r'$\Delta$p', 'bar'], 'm_flow' : [r'Mass flow $\dot{m}$ in kg/s', 'green', r'$\dot{m}$', 'kg/s'], 'Q_flow' : [r'Heat flow $\dot{Q}$ in W', 'orange', r'$\dot{Q}$', 'W'], }, index=['label', 'color', 'symbol', 'unit']) else: # check if physical_quantities has the required rows; raise value error otherwise required_rows = ['label', 'color', 'symbol', 'unit'] if not isinstance(physical_quantities, pd.DataFrame): raise TypeError("physical_quantities must be a pandas DataFrame.") if not set(required_rows).issubset(physical_quantities.index): raise ValueError("physical_quantities DataFrame must have rows: 'label', "+ "'color', 'symbol', 'unit'") 
self.physical_quantities = physical_quantities def _set_default_plot_settings(self, plot_settings: Optional[dict] = None): """ Set default plot settings that differ from matplotlib defaults. Parameters ---------- plot_settings : dict, optional Dictionary containing default plot settings. Whatever is not provided, will be taken from a default dictionary. Returns ------- None """ # Add any default settings # Important: Only settings, that are not part of matplotlib.rcParams, go here. # Any rcParams should be set through styles. default_plot_settings = { # 'lab_hor_ylabels': False, } # create dummy plot_settings if not provided if plot_settings is None: plot_settings = {} # warn if additional settings are provided (they might be ignored by the plot methods) additional_settings = [key for key in plot_settings if key not in default_plot_settings] if len(additional_settings) > 0: warnings.warn("You are introducing the following additional settings to the "+ "'default_plot_settings' (they might be ignored by the plot methods): "+ str(additional_settings)) # integrate specific settings into the default settings default_plot_settings.update(plot_settings) # set the default settings of the Plotter object self.default_plot_settings = default_plot_settings def _cols_check_n_correct(self, df : pd.DataFrame, cols: Optional[Union[str, List[str]]] = None, required_lengths: Optional[List[int]] = None) -> list: """ Return a proper cols list. - If cols is None, all columns from the DataFrame will be used. - If cols is a string, it will be wrapped in a list containing the string. - If cols is a list, it will be left as is. - The function checks if all the names in cols exist in the DataFrame. - If any of the columns in cols do not exist in the DataFrame, a warning will be given. - If cols is empty after removing non-existing columns, a ValueError will be raised. - If required_lengths is provided, the function will check if the length of cols is in the required length. 
If not, a ValueError will be raised. Parameters ---------- df : pd.DataFrame The DataFrame to check the columns against. cols : str or list of str, optional The columns to be checked and corrected. required_lengths : list of int, optional The list of acceptable length of cols. If not provided, no check will be done. Returns ------- list of str The corrected cols list. """ if isinstance(cols, str): cols = [cols] if cols is None: cols = list(df.columns) if not all(col in df.columns for col in cols): missing_cols = [col for col in cols if col not in df.columns] warnings.warn("Some of the columns in cols do not exist in the DataFrame:\n"+ f"{missing_cols} \n"+ "Those columns will be removed from cols.") cols = [col for col in cols if col in df.columns] if cols == []: raise ValueError("cols is empty after removing non-existing columns. This means\n"+ "that either the DataFrame is empty or you passed only cols that\n"+ "do not exist in the DataFrame.") if required_lengths: if len(cols) not in required_lengths: raise ValueError(f"cols must have a length of {required_lengths}. You provided "+ f"cols {cols} with a length of {len(cols)}.") return cols def _infer_label(self, cols: list, label: Optional[str] = None, axis_context: Optional[str] = 'this axis' ) -> str: """ Try to infer a proper label, if it is None from the variables names in cols. Parameters ---------- cols : list List of column names that are plotted on the axis. label : str, optional Label for the axis. Default is None. Returns ------- str The inferred label for the y-axis. """ if label is False: # If the label is False, no label will be used return None if label is not None: # If a label was explcitely provided, it will be used as label return label # Begin constructing a warning message warn_mess = f"You did not pass a label name for {axis_context}." 
label_dict = self.physical_quantities.loc['label', :].to_dict() # find matches between cols and label_dict keys matches = [] for y_col in cols: this_match = _find_longest_match(label_dict.keys(), y_col) if (this_match is not None) and this_match not in matches: matches += [this_match] if len(matches) == 1: # One match is found --> use it as label inf_label = label_dict[matches[0]] warn_mess += f" Label: {inf_label} was inferred from the column names." else: # No match or multiple matches were found --> return None inf_label = None warn_mess += f" Inferring from the column names {cols} was unsuccessful." if len(matches) == 0: warn_mess += " (No match)" if len(cols) == 1: inf_label = cols[0] # If only one column is given, use it as label else: warn_mess += " (Multiple matches)" warnings.warn(warn_mess) return inf_label # TODO: Title inference should also take xcols into account (e.g. for scatter: 'xcol over ycol') def _infer_title(self, plot_type: str, ycols: list, style: str, title: Optional[Union[str, bool]] = None ) -> str: """ Try to infer a proper title, if it is None or True for the plot from the plot type Parameters ---------- plot_type : str A string that is passed from the calling plot function to name the plot type for a title e.g. 'Timeseries plot', 'Histogram', 'Scatter plot' ycols : list List of column names that are plotted on the axis. Used in the construction of an inferred title. style : str The style of the plot. Used to check if a title should be inferred. title : str, optional Title for the plot. Default is None. If True or 'infer', the title will be inferred. If False or None, no title will be used. If string is passed, it will be used as title. Any not-string or any bool(str) != True (e.g. '') will lead to a behavior depending on the style. Returns ------- str The inferred title for the plot. If no title shall be used, None is returned. 
""" # If the title is a string and not 'infer', it will be used as title if title and isinstance(title, str) and title != 'infer': return title # Check if the title shall be inferred infer_title = (bool(title is not False and title is not None) and bool(title is True or title == 'infer' or style[1].split('.')[-1] == 'screen')) if infer_title: title = plot_type + " of " + ', '.join([str(col) for col in ycols]) else: # Return None if no title shall be inferred title = None return title def _setup_plot(self, title: str, ax: Optional[plt.Axes] = None, fig: Optional[plt.Figure] = None, **setup_settings): """ Set up a plot with the given title, x-axis label, y-axis label. The method will create a new figure and axes if none are provided. Parameters: ---------- title : str The title of the plot. xlabel : str The label for the x-axis. ylabel : str The label for the y-axis. ax : matplotlib.axes.Axes, optional The axes on which to plot. If not provided, a new figure and axes will be created. fig : matplotlib.figure.Figure, optional The figure on which to plot. If not provided, a new figure and axes will be created. setup_settings : dict Additional keyword arguments to pass to the plot Returns: ------- fig : matplotlib.figure.Figure The Figure object. ax : matplotlib.axes.Axes The Axes object. 
""" # check ax and fig fig_settings_names = ['figsize', 'layout'] fig_settings = {key: setup_settings[key] for key in fig_settings_names if key in setup_settings} if ax is None and fig is None: fig, ax = plt.subplots(**fig_settings) elif ax is None or fig is None: raise ValueError("Both ax and fig must be provided or none of them.") else: # raise TypeError if ax is not an Axes object if not isinstance(ax, plt.Axes): raise TypeError("ax must be a matplotlib.axes.Axes object.") # raise TypeError if fig is not a Figure object if not isinstance(fig, plt.Figure): raise TypeError("fig must be a matplotlib.figure.Figure object.") ax.set_title(title) return fig, ax def _setup_labels(self, ax, xlabel: str, ylabel: str, is_sec_axis: bool = False, xlabel_settings: dict = None, ylabel_settings: dict = None, **label_settings): """ Set up the labels for the x and y axes. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot. xlabel : str The label for the x-axis. ylabel : str The label for the y-axis. is_sec_axis : bool, optional Indicates, whether the given ax object is a secondary axis (needed for ylabel position) xlabel_settings : dict, optional Additional settings for the x-axis label. ylabel_settings : dict, optional Additional settings for the y-axis label. label_settings : dict Returns ------- None """ # Create a copy to not change the original settings dict label_settings_ = label_settings.copy() # Remove ylabel specific parameters from label_settings ylabel_exclusives = ['hor_ylabels', 'horizontal_ylabels'] custom_ylabel_settings = {key: label_settings_.pop(key) for key in ylabel_exclusives if (key in label_settings_ or key.startswith('ylabel'))} if not is_sec_axis: # Setting xlabel only for primary axis # Updating general label settings with xlabel specifics. xlabel specifics take precedence. 
xlabel_settings = {**label_settings_, **xlabel_settings} ax.set_xlabel(xlabel, **xlabel_settings) else: ax.set_xlabel(xlabel) # Setting ylabel # optional settings for horizontal y labels # Check for and remove the 'hor_ylabels' and 'horizontal_ylabels' settings if custom_ylabel_settings.pop('hor_ylabels', False) or custom_ylabel_settings.pop('horizontal_ylabels', False): ylabel_settings.update({'rotation':'horizontal', 'rotation_mode':"anchor", 'verticalalignment':'baseline', 'ha':'left'}) ax.yaxis.set_label_coords(-0.05, 1.02) if is_sec_axis: ax.yaxis.set_label_coords(1.05, 1.02) ylabel_settings.update({'ha':'right'}) ylabel_settings = {**label_settings_, **ylabel_settings} ax.set_ylabel(ylabel, **ylabel_settings)
[docs] def get_last_context(self): """ Get the last style context used for plotting. Returns ------- dict The last style context used for plotting. """ if self._last_context is None: raise ValueError("No last context exists. Last context available after a plot call.") return self._last_context.copy()
def create_legend(ax, loc='best', labels=None, **kwargs):
    """
    Create a (fancy) legend for the plot.

    This function is a wrapper around the matplotlib legend function. It removes the
    existing legend(s) and creates a new one. Improvements over the matplotlib legend
    function are:
        - The location specifier 'outside' can be used to place the legend outside the plot
          area.
        - The direction of the plot entries is inferred (if not controlled via ncols), e.g.
          the entries are spread horizontally if placed above or below the plot.
        - A legend for two y axes is supported. Entries of both axes are put into the same
          legend. Subtitles for the two axes can be passed via the 'title' kwarg as a tuple of
          two strings. The entries of the two axes are spread over the columns of the legend.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or tuple/list of one or two Axes
        The axes containing the plot. For two axes, the first one is the primary (left) axis
        and the second one the secondary (right) axis.
    loc : str, optional
        The location of the legend. Default is 'best'.
    labels : list, dict, optional
        Labels for the legend. If not provided, the labels of the plot will be used. If a
        dictionary is provided, it maps existing labels to new ones (unmapped labels are kept).
    **kwargs : dict
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    leg : matplotlib.legend.Legend
        The legend object.

    Raises
    ------
    ValueError
        If an empty tuple/list of axes is passed.
    NotImplementedError
        If more than two axes are passed.
    TypeError
        If ax is neither an Axes object nor a tuple/list of Axes objects.
    """
    if isinstance(ax, (tuple, list)):
        if len(ax) == 0:
            raise ValueError("No Axes were passed. ax must be a matplotlib.axes.Axes object or "+
                             "a tuple of one or two Axes objects.")
        if not all(isinstance(a, plt.Axes) for a in ax):
            raise TypeError("ax must be a matplotlib.axes.Axes object or a tuple of Axes objects.")
        if len(ax) > 2:
            raise NotImplementedError("Legend creation is only supported for two axes at"+
                                      " the moment. A tuple of more than 2 axes was passed.")
        if len(ax) == 1:
            # A single axis wrapped in a sequence behaves like a plain Axes object
            return _create_single_legend(ax[0], loc, labels, **kwargs)
        return _create_double_legend(tuple(ax), loc, labels, **kwargs)
    if isinstance(ax, plt.Axes):
        return _create_single_legend(ax, loc, labels, **kwargs)
    raise TypeError("ax must be a matplotlib.axes.Axes object or a tuple of Axes objects.")
def _create_double_legend(axs, loc, labels, **kwargs):
    """
    Create a legend for a plot with two y axes.

    Parameters
    ----------
    axs : tuple of matplotlib.axes.Axes
        The two axes containing the plot (primary/left first, secondary/right second).
    loc : str
        The location of the legend. Possible location specifiers are center, lower, upper,
        right and left and their combinations:
        'best' (Axes only), 'upper right', 'upper left', 'lower left', 'lower right', 'right',
        'center left', 'center right', 'lower center', 'upper center', 'center'.
        See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html
        Additionally, legend placement outside the plot area is possible by prefixing the
        location with 'outside', e.g. 'outside upper right'. For outside locations, the order
        of left/right and upper/lower may be swapped: the first specifier names the side of
        the plot area and the second the alignment on that side.
        See https://matplotlib.org/stable/users/explain/axes/legend_guide.html
    labels : list, dict
        Labels for the legend. If None, the labels of the plot will be used. A list must
        contain one label per entry of both axes (primary first). A dictionary maps existing
        labels to new ones.
    **kwargs : dict
        Additional keyword arguments to pass to the legend. 'title' may be a tuple of two
        strings used as subtitles for the entries of the two axes.

    Returns
    -------
    leg : matplotlib.legend.Legend
        The legend object.
    """
    # Remove the legend from the figure (if it was a figure legend)
    figure_legends = [c for c in axs[0].get_figure().get_children()
                      if isinstance(c, matplotlib.legend.Legend)]
    for old_leg in figure_legends:
        old_leg.remove()
    # Remove the legends from the axes
    for ax in axs:
        try:
            ax.get_legend().remove()  # Remove if present
        except AttributeError:
            pass

    is_outside = isinstance(loc, str) and loc.startswith('outside')

    # See if ncols is explicitely passed
    if 'ncols' in kwargs:
        ncols = kwargs.pop('ncols')
    elif 'ncol' in kwargs:
        warnings.warn('"ncol" was passed as a parameter in legend. "ncol" is deprecated.'+
                      ' Use "ncols" instead.')
        ncols = kwargs.pop('ncol')
    elif is_outside and loc.startswith(('outside right', 'outside left')):
        ncols = 1
    else:
        ncols = 2  # Default for two y axes

    # Prepare the legend entries
    h1, l1 = axs[0].get_legend_handles_labels()
    h2, l2 = axs[1].get_legend_handles_labels()

    # update labels if explicitely given
    if labels is not None:
        if isinstance(labels, list):
            if len(labels) == len(l1) + len(l2):
                n_primary = len(l1)
                l1 = labels[:n_primary]
                l2 = labels[n_primary:]
            else:
                warnings.warn(f"The number of provided labels ({len(labels)}) does not match the "+
                              f"total number of legend entries ({len(l1) + len(l2)}). "+
                              "The labels will not be updated.")
        elif isinstance(labels, dict):
            l1 = [labels.get(label, label) for label in l1]
            l2 = [labels.get(label, label) for label in l2]
        else:
            warnings.warn("Labels are not a list. The labels will not be updated.")

    # Divide the legend entries among the columns
    cols_1 = 1  # Number of columns for primary axis entries
    if ncols == 1:
        n_rows = len(h1) + len(h2)
    elif ncols == 2:
        n_rows = max(len(h1), len(h2))
    else:
        # Search the split of columns between the two axes that minimizes the number of rows.
        # Note: -(a // -b) is the ceiling division of a by b.
        n_rows = 10**6  # A very large number
        while cols_1 < ncols:
            n_rows_1 = -(len(h1) // -cols_1)
            n_rows_2 = -(len(h2) // -(ncols - cols_1))
            n_rows_new = max(n_rows_1, n_rows_2)
            if n_rows_new > n_rows:  # Minimum was found
                break
            n_rows = n_rows_new
            cols_1 += 1
        cols_1 -= 1  # step back to the best split

    # Get the legend section titles
    if ('title' in kwargs and isinstance(kwargs['title'], (tuple, list))
            and len(kwargs['title']) == 2):
        title = kwargs.pop('title')
    else:
        title = ('Left axis:', 'Right axis:')

    # Fill the columns with the legend entries (legend fills column-wise)
    if ncols == 1:
        leg_handles = [title[0]] + h1 + [title[1]] + h2
        leg_labels = [''] + l1 + [''] + l2
    else:
        leg_handles = []
        leg_labels = []
        which = 0
        for col in range(ncols):
            if col == cols_1:
                which = 1
            handles_left, labels_left = (h1, l1) if which == 0 else (h2, l2)
            if col in (0, cols_1):
                # First column of a section: add the section title
                leg_handles.append(title[which])
            else:
                # Consecutive columns: add an invisible entry in the title row
                leg_handles.append(plt.Line2D([0], [0], color='none'))
            leg_labels.append('')
            for _ in range(n_rows):
                if handles_left and labels_left:  # There are more entries to place
                    leg_handles.append(handles_left.pop(0))
                    leg_labels.append(labels_left.pop(0))
                else:  # Pad with invisible entries
                    leg_handles.append(plt.Line2D([0], [0], color='none'))
                    leg_labels.append('')

    class LegendTitle:
        """
        Legend handler that draws a string handle as a legend (sub)title.

        The title is implemented as a legend entry whose text is shown at the handle
        position.

        Attributes
        ----------
        text_props : dict
            Text properties to customize the appearance of the legend title.
        """
        def __init__(self, text_props=None):
            self.text_props = text_props or {}
            super().__init__()

        def legend_artist(self, legend, orig_handle, fontsize, handlebox):
            x0, y0 = handlebox.xdescent, handlebox.ydescent
            title_artist = matplotlib.text.Text(x0, y0, orig_handle, **self.text_props)
            handlebox.add_artist(title_artist)
            return title_artist

    # The subtitles use the fontsize of the legend entries
    fontsize = kwargs.get('fontsize', plt.rcParams['legend.fontsize'])
    handler_map = {str: LegendTitle({'fontsize': fontsize})}

    if is_outside:
        loc_parts = loc.split(' ')
        side = loc_parts[1] if len(loc_parts) > 1 else ''
        # Turn the frame off, unless explicitely stated otherwise
        kwargs['frameon'] = kwargs.get('frameon', False)
        if 'mode' not in kwargs and side in ['upper', 'lower']:
            kwargs['mode'] = 'expand'
        if side in ['left', 'lower', 'right']:
            # A figure legend is necessary to have the legend properly outside
            leg = axs[1].get_figure().legend(leg_handles, leg_labels, ncols=ncols, loc=loc,
                                             handler_map=handler_map, **kwargs)
        else:
            loc = _get_legend_loc(loc, kwargs.pop('bbox_to_anchor', None))
            leg = axs[1].legend(leg_handles, leg_labels, ncols=ncols, **loc,
                                handler_map=handler_map, **kwargs)
    else:
        leg = axs[1].legend(leg_handles, leg_labels, ncols=ncols, loc=loc,
                            handler_map=handler_map, **kwargs)
    return leg


def _create_single_legend(ax, loc, labels, **kwargs):
    """
    Create a legend for a plot with a single y axis.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes containing the plot.
    loc : str
        The location of the legend (see _create_double_legend for 'outside' locations).
    labels : list, dict, optional
        Labels for the legend. If None, the labels of the plot will be used. A dictionary maps
        existing labels to new ones (unmapped labels are kept).
    **kwargs : dict
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    leg : matplotlib.legend.Legend
        The legend object.
    """
    if isinstance(labels, dict):
        old_labels = ax.get_legend_handles_labels()[1]
        labels = [labels.get(label, label) for label in old_labels]

    # Remove existing figure legends and the axes legend
    figure_legends = [c for c in ax.get_figure().get_children()
                      if isinstance(c, matplotlib.legend.Legend)]
    for old_leg in figure_legends:
        old_leg.remove()
    try:
        ax.get_legend().remove()  # Remove if present
    except AttributeError:
        pass

    if isinstance(loc, str) and loc.startswith('outside'):
        loc_parts = loc.split(' ')
        side = loc_parts[1] if len(loc_parts) > 1 else ''
        # Above/below the plot: spread the entries horizontally (at least one column)
        if 'ncols' not in kwargs and side in ['lower', 'upper']:
            kwargs['ncols'] = max(1, len(ax.get_legend_handles_labels()[0]))
        # Turn the frame off, unless explicitely stated otherwise
        kwargs['frameon'] = kwargs.get('frameon', False)
        if side in ['left', 'lower']:
            # A figure legend is necessary to have the legend outside
            leg = ax.get_figure().legend(labels=labels, loc=loc, **kwargs)
        else:
            loc = _get_legend_loc(loc, kwargs.pop('bbox_to_anchor', None))
            leg = ax.legend(labels=labels, **loc, **kwargs)
    else:
        leg = ax.legend(labels=labels, loc=loc, **kwargs)
    return leg


def _get_legend_loc(loc, bbox_to_anchor=None):
    """
    Get the location arguments for an Axes legend.

    Parameters
    ----------
    loc : str or dict
        The location of the legend. A dict is returned unchanged.
    bbox_to_anchor : tuple, optional
        The bbox_to_anchor of the legend. If given, it is combined with loc as-is.

    Returns
    -------
    dict
        Keyword arguments for the legend. For custom 'outside …' locations, the dict contains
        'loc' and 'bbox_to_anchor' (in axes coordinates). Otherwise it is {'loc': loc}.
    """
    # Outside locations: anchor the legend at the edge of the axes and align it away from it
    custom_locs = {'outside upper left': {'loc': 'lower left', 'bbox_to_anchor': (0, 1, 1, 1)},
                   'outside upper right': {'loc': 'lower right', 'bbox_to_anchor': (0, 1, 1, 1)},
                   'outside upper center': {'loc': 'lower center', 'bbox_to_anchor': (0, 1, 1, 1)},
                   'outside lower left': {'loc': 'upper left', 'bbox_to_anchor': (0, 0, 1, 0)},
                   'outside lower right': {'loc': 'upper right', 'bbox_to_anchor': (0, 0, 1, 0)},
                   'outside lower center': {'loc': 'upper center', 'bbox_to_anchor': (0, 0, 1, 0)},
                   'outside right upper': {'loc': 'upper left', 'bbox_to_anchor': (1, 1)},
                   'outside right lower': {'loc': 'lower left', 'bbox_to_anchor': (1, 0)},
                   'outside right center': {'loc': 'center left', 'bbox_to_anchor': (1, 0.5)},
                   'outside left upper': {'loc': 'upper right', 'bbox_to_anchor': (0, 1)},
                   'outside left lower': {'loc': 'lower right', 'bbox_to_anchor': (0, 0)},
                   'outside left center': {'loc': 'center right', 'bbox_to_anchor': (0, 0.5)}}
    if isinstance(loc, dict):
        return loc
    if bbox_to_anchor is not None:
        return {'loc': loc, 'bbox_to_anchor': bbox_to_anchor}
    if loc in custom_locs:
        return custom_locs[loc]
    return {'loc': loc}

### HELPER FUNCTIONS ###
def _extract_kwargs_with_prefix(kwargs, prefix):
    """
    Extract keyword arguments from kwargs that start with a certain prefix.

    The matching keyword arguments are removed from kwargs (the kwargs object is changed).
    The prefix argument itself is never modified.

    Parameters
    ----------
    kwargs : dict
        Dictionary containing keyword arguments.
    prefix : str or iterable of str
        The prefix(es) to search for in the keys of kwargs. A trailing '_' is added to each
        prefix if missing. If several prefixes match a key, the longest one is used.

    Returns
    -------
    dict
        The extracted keyword arguments, with the matched prefix (including its '_') removed
        from the keys, e.g. {'legend_loc': 'best'} with prefix 'legend' -> {'loc': 'best'}.
    """
    if isinstance(prefix, str):
        prefix = [prefix]
    # Build a new, normalized list instead of modifying the caller's sequence
    prefixes = [p if p.endswith('_') else p + '_' for p in prefix]
    prefixes.sort(key=len, reverse=True)

    extracted_settings = {}
    for key in list(kwargs):
        matched = next((p for p in prefixes if key.startswith(p)), None)
        if matched is not None:
            extracted_settings[key[len(matched):]] = kwargs.pop(key)
    return extracted_settings


def _find_longest_match(searchlist, match, mode='start'):
    """
    Find the element of searchlist that has the longest match with match.

    - the element from searchlist must be totally contained in match
    - the matching is done at the start (default) or at the end of match, i.e. using
      .startswith() or .endswith() respectively.

    This method is used to infer labels for the axes. Consider
    searchlist = ['Q','Q_flow','Q_flow_loss'] and
        - match1 = 'Q_flow_loss_xy' => returns 'Q_flow_loss'
        - match2 = 'Q_flow_loss' => returns 'Q_flow_loss'
        - match3 = 'Q_flow_xy' => returns 'Q_flow'
        - match4 = 'Q_xy' => returns 'Q'

    Parameters
    ----------
    searchlist : iterable of str
        Strings to search the longest match in.
    match : str
        Search string (non-strings are converted with str()).
    mode : str
        Whether to match at the 'start' or 'end' of match.

    Returns
    -------
    item : str or None
        Longest match if any, None otherwise.

    Raises
    ------
    ValueError
        If mode is neither 'start' nor 'end'.
    """
    if mode == 'start':
        is_match = str.startswith
    elif mode == 'end':
        is_match = str.endswith
    else:
        raise ValueError(f"'mode' must be one of 'start' or 'end'. You provided '{mode}'.")

    match = str(match)  # cast match to string
    for item in sorted(searchlist, key=len, reverse=True):
        if is_match(match, item):
            return item
    return None


def _get_valid_plot_kinds_pandas():
    """
    Get the valid plot kinds from the pandas documentation.

    This is done by parsing the docstring of the pandas plot accessor. The kinds are listed
    after the line containing 'kind : str', as consecutive lines of the form "- 'kind' : ...".
    If the docstring is unavailable (e.g. when running with python -OO) or cannot be parsed,
    a warning is issued and a fallback list of the plot kinds documented by pandas is
    returned.

    Returns
    -------
    valid_kinds : list of str
        List of valid plot kinds.
    """
    fallback_kinds = ['line', 'bar', 'barh', 'hist', 'box', 'kde', 'density', 'area', 'pie',
                      'scatter', 'hexbin']
    docstring = pd.DataFrame().plot.__doc__ or ''
    docstring_lines = [line.strip() for line in docstring.split('\n') if line.strip() != '']

    kind_line = next((i for i, line in enumerate(docstring_lines) if 'kind : str' in line), None)
    valid_kinds = []
    if kind_line is not None:
        in_list = False
        for line in docstring_lines[kind_line + 1:]:
            if line.startswith('-'):
                in_list = True
                parts = line.split("'")
                if len(parts) > 1:
                    valid_kinds.append(parts[1])
            elif in_list:
                break  # first line after the bullet list

    if not valid_kinds:
        warnings.warn("Could not parse the valid plot kinds from the pandas documentation. "+
                      f"Falling back to {fallback_kinds}.")
        valid_kinds = list(fallback_kinds)
    return valid_kinds


def _clip_fig(fig):
    """
    Clip the figure to its tight bounding box and adjust the size accordingly.

    This is useful when the figure has axes with a fixed aspect ratio, to remove white
    margins. If no axes have a fixed aspect ratio, no clipping is done and a warning is
    issued. The figure is scaled so that it does not grow: one dimension keeps its original
    size and the other one is trimmed.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to clip. It is modified in place.

    Returns
    -------
    None
    """
    # check if any of the axes in the figure have a fixed aspect
    is_fixed = any(ax.get_aspect() != 'auto' for ax in fig.axes)
    if not is_fixed:
        warnings.warn("No axes with fixed aspect ratio found. No clipping needed. If you want to " \
                      "change the figure size, you can use the parameter figsize=(width, height) in the call " \
                      "of plotter.plot().")
        return

    # Clipping might not work well with colorbars and more than one subplot
    has_colorbar = any(hasattr(ax, '_colorbar') for ax in fig.axes)
    if has_colorbar and len(fig.axes) > 2:
        warnings.warn("Clipping might not work well with colorbars and multiple subplots. Maybe " \
                      "you must adjust the figsize manually to fit your needs, e.g. by adding " \
                      "the parameter figsize = (width, height) (in inches) to the call of " \
                      "plotter.plot().")

    # Draw first so that the layout is done and the bbox is correct
    fig.canvas.draw()
    bbox = fig.get_tightbbox(fig.canvas.get_renderer())
    width, height = bbox.width, bbox.height

    # Scale so that the dimension that must not be trimmed keeps its size
    old_width, old_height = fig.get_size_inches()
    factor = min(old_width / width, old_height / height)
    fig.set_size_inches(width * factor, height * factor)

    # draw the figure again to apply the new size
    fig.canvas.draw()

### EXAMPLE USAGE ###
def example(example_types=('all',)):
    """
    Example usage of the Plotter class.

    Builds small sample DataFrames and renders a series of demonstration plots
    with :class:`Plotter`, grouped into sections that can be selected individually.

    Parameters
    ----------
    example_types : list/tuple of str, or str, optional
        Plot types to demonstrate. Default is ('all',). A single string is treated
        as a one-element list and ``None`` is treated as ('all',).
        Possible values are:
        'all', 'style_examples', 'pd_kinds', 'multi_scatter', 'R2_plots',
        'sec_axis', 'subplots', 'cycler'.
        Unknown values emit a warning and are otherwise ignored.
    """
    valid_types = ('all', 'style_examples', 'pd_kinds', 'multi_scatter', 'R2_plots',
                   'sec_axis', 'subplots', 'cycler')
    if example_types is None:
        example_types = ('all',)
    elif isinstance(example_types, str):
        # Without this, a bare string would be matched by substring containment
        # (e.g. 'all' in 'all_x' is True).
        example_types = (example_types,)
    unknown = [t for t in example_types if t not in valid_types]
    if unknown:
        warnings.warn(f"Unknown example type(s) ignored: {unknown}. "
                      f"Possible values are: {list(valid_types)}")

    def wanted(name):
        """Return True if the example section ``name`` was requested."""
        return 'all' in example_types or name in example_types

    # Create a sample DataFrame: 48 hourly values
    data_list = list(range(48))
    data = {'T_1': data_list,
            'T_2': [sqrt(x) for x in data_list],
            'T_3': [sqrt(x) + 5 for x in data_list],
            'm_flow': [10 + x for x in data_list]}
    hours = pd.date_range(start="2023-01-01", end="2023-01-02 23:00", freq="1h")
    df = pd.DataFrame(data, index=hours)

    # Create Plotter instances (default style and presentation style)
    plotter = Plotter()
    plotter_presentation = Plotter(style_default='presentation')

    # demonstrate different styles, behavior without specifying ycols
    if wanted('style_examples'):
        plotter.plot(df)
        plotter_presentation.plot(df)
        plotter.plot(df, style='presentation')
        plotter.plot(df, style='paper')
        # behavior with specifying title as 'infer'
        plotter.plot(df, style='paper', title='infer')
        # behavior with specifying ycols
        plotter.plot(df, style='paper', ycols=['T_1', 'T_2'])
        # behavior with specifying ycols and ylabel:
        plotter.plot(df, ycols=['T_1', 'T_2'], ylabel='Temperatur', title='Custom Title')
        plotter.plot(df, ycols=['T_1', 'T_4'])  # specify a non-existing column

    # Plot data using the plot method with different pandas kinds of plots
    if wanted('pd_kinds'):
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Line plot', kind='line')
        plotter.plot(df, xcols='T_1', ycols='T_2', title='Bar plot', kind='bar')
        plotter.plot(df, xcols='T_1', ycols='T_2', title='Hor. Bar plot', kind='barh')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Hist plot', kind='hist', alpha=0.5)
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Box plot', kind='box')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='KDE plot', kind='kde')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Density plot', kind='density')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Area plot', kind='area')
        plotter.plot(df.iloc[0:5, :], ycols='T_1', title='Pie plot', kind='pie')
        plotter.plot(df, title='Scatter plot', kind='scatter', xcols='T_1', ycols='T_2',
                     s='T_1', c='m_flow', marker='x', linestyle='solid')
        plotter.plot(df, xcols='T_1', ycols='T_2', title='Hexbin plot', kind='hexbin')

    # examples with multi_scatter
    if wanted('multi_scatter'):
        plotter.plot(df, xcols='T_1', ycols=['T_2', 'T_3'], title='Multi scatter plot',
                     kind='multi_scatter', legend_loc='outside upper left', marker='cycle')
        plotter.plot(df, xcols=['T_1', 'T_2'], ycols=['T_2', 'T_3'], title='Multi scatter plot',
                     kind='multi_scatter', leg_labels=['T_2 over T_1', 'T_3 over T_2'],
                     marker=['x', 'o'], legend_loc='lower right')

    # examples with R2 plots
    if wanted('R2_plots'):
        # simple R2 plots
        plotter.plot(df, xcols='T_2', ycols='T_2', title='R2 scatter plot', kind='R2_scatter')
        plotter.plot(df, xcols='T_2', ycols='T_2', title='R2 hexbin plot', kind='R2_hexbin')
        # examples with more customization
        plotter_presentation.plot(df, xcols='T_2', ycols='T_2', kind='R2_scatter',
                                  title='R2 scatter plot, customized, presentation',
                                  metrics=False, marker='x', color='red', alpha=0.5, s='T_1',
                                  xlabel='measured temperature', ylabel='modeled temperature',
                                  xlim=[0, 10], ylim=[0, 10])
        plotter.plot(df, xcols='T_2', ycols='T_2', kind='R2_hexbin',
                     title='R2 hexbin plot: customized, paper', metrics=['R2', 'RMSE'],
                     cmap='Reds', gridsize=20, xlabel='measured temperature',
                     ylabel='modeled temperature', xlim=[0, 10], ylim=[0, 10], style='paper')

    # examples with secondary axis
    if wanted('sec_axis'):
        plotter_presentation.plot(
            df,
            title='Timeseries with secondary y-axis (and outside legend)',
            ycols=['m_flow'],
            sec_ycols=['T_1', 'T_2'],
            sec_marker='x',
            sec_linestyle='',
            legend_loc='outside right upper'
        )
        # same with horizontal y labels
        plotter_presentation.plot(
            df,
            title='',
            lab_hor_ylabels=True,
            ycols=['m_flow'],
            sec_ycols=['T_1', 'T_2'],
            sec_marker='x',
            sec_linestyle='',
            legend_loc='outside right upper'
        )
        plotter.plot(df, kind='multi_scatter', title='Multi-Scatter plot', xcols='T_1',
                     ycols='T_2', xlim=[10, 20], sec_xcols='T_1', sec_ycols='m_flow',
                     legend_loc='outside lower left', legend_ncols=7)
        # Note that xcols and sec_xcols are different physical quantities
        # => no xlabel will be printed, unless defined explicitly
        # (e.g. by adding xlabel="T or m_flow" to the call below)
        plotter.plot(df, xcols=['T_1', 'T_2'], ycols=['T_2', 'T_3'], sec_xcols='m_flow',
                     sec_ycols='T_2', title='Multi scatter plot', kind='multi_scatter',
                     leg_labels=['T_2 over T_1', 'T_3 over T_2', 'T_2 over m_flow'],
                     legend_loc='lower right')

    # examples with subplots
    if wanted('subplots'):
        plots = {
            'Plot a timeseries': {'ycols': ['T_1', 'T_2'], 'kind': 'line'},
            'Histogram': {'ycols': 'm_flow', 'kind': 'hist'}
        }
        plotter.plot_subplots(df, plots)
        plotter.plot_subplots(df, plots, layout=(1, 2), figsize=(8, 4), wspace=0)

        # behavior with more than 3 plots => layout is inferred
        plots_2 = {
            'Plot a timeseries': {'ycols': ['T_3', 'T_2'], 'kind': 'line'},
            'Histogram': {'ycols': 'T_1', 'kind': 'hist'},
            'Histogram 2': {'ycols': 'T_2', 'kind': 'hist', 'title': None},
            'Histogram 3': {'ycols': 'm_flow', 'color': 'purple', 'bins': 20,
                            'grid': False, 'kind': 'hist'}
        }
        plotter.plot_subplots(df, plots_2)

    # Example on cycler usage
    if wanted('cycler'):
        # Create a sample DataFrame: 10 shifted sine curves
        x = np.linspace(0, 3 * np.pi, 20)
        y = {f'line_{i}': np.sin(x) + i / 2 for i in range(10)}
        df2 = pd.DataFrame(y, index=x)
        # Split the columns so every line lands on exactly one axis:
        # line_0..line_4 on the primary axis, line_5..line_9 on the secondary axis.
        primary_cols = list(df2.columns[:5])
        secondary_cols = list(df2.columns[5:])

        # Create an instance of Plotter
        plotter_cycler = Plotter(style_default='presentation')
        # Plot data using the plot method
        # No arguments
        plotter_cycler.plot(df2, title='Cycler example without arguments')
        # Different lengths arguments for linestyle, marker and color
        plotter_cycler.plot(df2, title='Cycler example with different lengths props',
                            linestyle=['dashed', 'dotted'], marker=['o', 'x', 's'],
                            color=['red', 'blue', 'green', 'orange'])
        # With secondary axis
        plotter_cycler.plot(df2, ycols=primary_cols, sec_ycols=secondary_cols,
                            title='Cycler example with secondary axis',
                            linestyle=['dashed', 'dotted'], ylim=[-1, 6],
                            marker=['o', 'x', 's'], sec_linestyle='',
                            color=['red', 'blue', 'green', 'orange'])
        # Using the 'cycle' keyword
        plotter_cycler.plot(df2, ycols=primary_cols, sec_ycols=secondary_cols,
                            title='Cycler example with secondary axis and linestyle "cycle"',
                            linestyle='cycle', ylim=[-1, 6], color='cycle', sec_color='red')
        plotter_cycler.plot(df2, ycols=primary_cols, sec_ycols=secondary_cols,
                            title='Cycler example with secondary axis and marker "cycle"',
                            linestyle='', marker='cycle', ylim=[-1, 6], color='cycle',
                            sec_color='red', markersize=[10, 15])
        # Using predefined cycles
        plotter_cycler.plot(df2, title='Cycler example with "filled" marker, "main" linestyle and "rgb" color',
                            marker='filled', linestyle='main', color='rgb')