"""
This module contains the Plotter class, which is designed for creating and managing plots based on
data from pandas DataFrames.
It provides functionalities to:
- initialize with a DataFrame, which is used for plotting
- plot data from the DataFrame (different types of plots)
- store and manage the plots
"""
from math import sqrt
from typing import Optional, Union, List
import warnings
from copy import deepcopy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from sattoolbox.plots import cyclers
[docs]
class Plotter:
"""
A class for convenient plotting from pandas DataFrames with matplotlib.
The plot(kind='...') method allows for plotting any kind of plot supported by
pandas.DataFrame.plot() and additional custom kinds (currently available: 'multi_scatter',
'R2_scatter', 'R2_hexbin'). The method takes care of setting up the figure, axes, and styles, and
allows for customization of the plot through keyword arguments.
The plot_subplots method allows to plot multiple subplots in a single figure.
SAT custom styles are available for different use cases (screen, presentation, paper).
The styles all use the sat_base style and extend or override it. The base style
mostly defines the default color cycle, and the setup of the axis and ticks.
Presentation Styles: (Optimized size and fontsize (18) for presentations)
- 'presentation': style for plots filling the whole width of the presentation slide.
- 'presentation_half': style for plots filling half the width (2 plots or plot plus text)
Paper style sheets:
- Reduced font size: 10
- Linewidth, Markersize and ticks-size reduced
- Various Paper Styles:
- Sizes: regular (1.5 col --> 5.5 inch), half --> 5.5 inch, full --> 7.4 inch
(see Elsevier author guidelines)
- Black&White Paper (bw_paper): Color (3 shades of grey/black) and Linestyle Cycle changed,
Grid turned off
Screen Style:
- Size fitted to be almost fullscreen on a Full HD Screen
- Facecolor: White
- Fontsize: 20
- Both minor and Major Grid activated
Available styles are:
- paper
- paper_half
- paper_full
- bw_paper
- bw_paper_half
- bw_paper_full
- screen
- presentation
- presentation_half
Attributes:
physical_quantities (pd.DataFrame): DataFrame containing information about physical quantities.
plot_kinds_pd (list): List of valid plot kinds from pandas.
plot_kinds_custom (list): List of custom plot kinds defined in this class.
kind_specifics (dict): Dictionary containing specifics for each plot kind.
sat_styles (list): List of available SAT styles.
style_default (str): Default style for the plots.
"""
# Per-kind configuration that plot() looks up by the requested `kind`:
# - 'defaults':     keyword defaults applied when the caller did not pass that key
# - 'xcols_check' / 'ycols_check': accepted type(s) for xcols / ycols and the message of the
#                   ValueError raised when the check fails
# - 'kind_to_plot': pandas kind actually drawn for a custom kind (multi_scatter -> line)
# - 'uses_cycler':  whether plot() installs a property cycler on the axes for this kind
# NOTE(review): the R2_scatter messages point to kind='R2_multi_scatter', which is not among the
# custom kinds registered in __init__ -- confirm whether that kind is planned.
kind_specifics = {
    'line': {'uses_cycler': True},
    'bar': {},
    'barh': {},
    'hist': {
        'xcols_check': {
            'type': type(None),
            'message': 'xcols must be None for histograms (do not specify)',
        },
    },
    'box': {
        'defaults': {'legend': False},
        'xcols_check': {
            'type': type(None),
            'message': 'xcols must be None for box plot (do not specify)',
        },
    },
    'kde': {'uses_cycler': True},
    'density': {'uses_cycler': True},
    'area': {
        'uses_cycler': True,
        'xcols_check': {
            'type': (type(None), str),
            'message': 'xcols must be None or a string for area plot.',
        },
    },
    'pie': {
        'xcols_check': {
            'type': type(None),
            'message': 'xcols must be None for pie plot (do not specify)',
        },
        'ycols_check': {
            'type': str,
            'message': 'ycols must be a string for pie plot.',
        },
    },
    'scatter': {
        'defaults': {'legend': False},
        'xcols_check': {
            'type': str,
            'message': ("xcols must be a string for scatter plot. To plot "
                        "multiple ycols over multiple xcols use kind='multi_scatter'"),
        },
        'ycols_check': {
            'type': str,
            'message': ("ycols must be a string for scatter plot. To plot "
                        "multiple ycols use kind='multi_scatter'"),
        },
    },
    'hexbin': {
        'defaults': {
            'grid': False,
            'legend': False,
            'mincnt': np.nextafter(0, 1),
            'gridsize': 50,
        },
        'xcols_check': {
            'type': str,
            'message': 'xcols must be a string for hexbin plot.',
        },
        'ycols_check': {
            'type': str,
            'message': 'ycols must be a string for hexbin plot.',
        },
    },
    'multi_scatter': {
        'defaults': {'linestyle': '', 'marker': 'o'},
        'kind_to_plot': 'line',
        'xcols_check': {
            'type': (str, list, tuple),
            'message': ("xcols must be a string or a list of strings for "
                        "multi_scatter plot."),
        },
        'ycols_check': {
            'type': (str, list, tuple),
            'message': ("ycols must be a string or a list of strings for "
                        "multi_scatter plot."),
        },
    },
    'R2_scatter': {
        'defaults': {'legend': False},
        'xcols_check': {
            'type': str,
            'message': ("xcols must be a string for R2_scatter plot. To plot "
                        "multiple ycols over multiple xcols use kind='R2_multi_scatter'"),
        },
        'ycols_check': {
            'type': str,
            'message': ("ycols must be a string for R2_scatter plot. To plot "
                        "multiple ycols use kind='R2_multi_scatter'"),
        },
    },
    'R2_hexbin': {
        'defaults': {
            'grid': False,
            'legend': False,
            'mincnt': np.nextafter(0, 1),
            'gridsize': 50,
        },
        'xcols_check': {
            'type': str,
            'message': 'xcols must be a string for R2_hexbin plot.',
        },
        'ycols_check': {
            'type': str,
            'message': 'ycols must be a string for R2_hexbin plot.',
        },
    },
}
# Names of the SAT style sheets.
# Kept as a class attribute so the list can be read directly from outside the class.
sat_styles = [
    'screen',
    'presentation',
    'presentation_half',
    'paper',
    'paper_half',
    'paper_full',
    'bw_paper',
    'bw_paper_half',
    'bw_paper_full',
]
def __init__(self,
             physical_quantities: Optional[pd.DataFrame] = None,
             plot_settings: Optional[dict] = None,
             style_default: Optional[str] = 'screen'
             ) -> None:
    """
    Initialize the Plotter object.

    Parameters
    ----------
    physical_quantities : Optional[pd.DataFrame], optional
        Optional pandas DataFrame containing information about physical quantities, by default
        None.
        The DataFrame must have columns: 'label', 'color', 'symbol', 'unit'.
        If not provided, a default DataFrame will be used.
    plot_settings : Optional[dict], optional
        Optional dictionary containing default plot settings, by default None.
        Whatever is not provided, will be taken from a default dictionary.
        Note: These are settings that are NOT part of plt.rcParams (those are controlled
        via the style parameters or set explicitly in the plot method).
    style_default : Optional[str], optional
        Optional string specifying the style of the plots, by default 'screen'.
        Available styles are the names in Plotter.sat_styles and any style from
        plt.style.available.

    Raises
    ------
    ValueError
        If one of the custom plot kinds collides with a plot kind provided by pandas.
    """
    self._set_physical_quantities(physical_quantities)
    self._set_default_plot_settings(plot_settings)
    self.style_default = self._style_check_n_correct(style_default)
    self._last_context = None  # to keep track of the last style context used
    self.plot_kinds_pd = _get_valid_plot_kinds_pandas()
    # custom plot types that are not part of pandas
    self.plot_kinds_custom = ['multi_scatter', 'R2_scatter', 'R2_hexbin']

    # Custom kinds must not shadow the kinds provided by pandas. This is checked before the
    # shared class-level dict is touched, so a failing constructor leaves it unchanged.
    for custom_kind in self.plot_kinds_custom:
        if custom_kind in self.plot_kinds_pd:
            raise ValueError(f"Invalid custom plot kind '{custom_kind}'. "
                             f"Must NOT be one of the pandas kinds: {self.plot_kinds_pd}.")

    # TODO Put this in a better place. This is a strange place to change the kind specifics dict
    # NOTE: kind_specifics is a class attribute, so the bound methods stored here are shared by
    # all Plotter instances and are overwritten each time a new instance is created.
    Plotter.kind_specifics['multi_scatter']['plot_func'] = self._plot_multi_scatter
    Plotter.kind_specifics['R2_scatter']['plot_func'] = self._plot_R2_scatter
    Plotter.kind_specifics['R2_hexbin']['plot_func'] = self._plot_R2_hexbin
### PUBLIC METHODS ###
# Note towards "typing": Union[] might be deprecated, at least as of python 3.9 it is possible
# to write str | list[str]
[docs]
def plot(self,
         df: pd.DataFrame,
         kind: Optional[str] = 'line',
         xcols: Optional[Union[str, List[str]]] = None,
         xlabel: Optional[str] = None,
         ycols: Optional[Union[str, List[str]]] = None,
         ylabel: Optional[str] = None,
         ax : Optional[plt.Axes] = None,
         fig : Optional[plt.Figure] = None,
         style : Optional[str] = None,
         **kwargs
         ) -> tuple[plt.Figure, Union[plt.Axes, tuple[plt.Axes]]]:
    """
    Plot from a DataFrame using pandas DataFrame.plot().

    This method adds some convenience features to the pandas plot method, such as handling of
    styles, titles, and legends.
    It is a general method for plotting and can be used for any kind of plot supported by pandas
    (see Plotter.plot_kinds_pd) plus some additional plot types (see Plotter.plot_kinds_custom).

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data to plot.
    kind : str, optional
        Type of plot to create. Check Plotter.plot_kinds_pd and Plotter.plot_kinds_custom for
        available plot types.
        Default is 'line'.
    xcols : str or list of str, optional
        List of the column(s) to use for the x-axis. If not provided, the index of the DataFrame
        will be used as x-values.
        If multiple ycols are provided, xcols must either be a list of the same length or a
        single column name that will be used for all ycols.
    xlabel : str, optional
        Label for the x-axis. Default is None.
        If None, the label will be inferred from the xcols name(s).
    ycols : str or list of str, optional
        Name of the column(s) to use for the y-axis. If not provided, all columns from the
        DataFrame will be plotted.
    ylabel : str, optional
        Label for the y-axis. Default is None.
        If None, the label will be inferred from the ycols name(s).
    ax : matplotlib.axes.Axes, optional
        The axes on which to plot. If not provided, a new figure and axes will be created.
    fig : matplotlib.figure.Figure, optional
        The figure on which to plot. If not provided, a new figure will be created.
    style : str, optional
        Style to use for the plot. If not provided, the default style of the Plotter instance
        will be used.
    **kwargs : dict
        Additional keyword arguments to pass to the plot.
        All keyword arguments accepted by pandas.DataFrame.plot are accepted.
        Keyword arguments starting with 'sec_' or 'secondary_' are treated as settings for a
        secondary y-axis.
        The deprecated spellings 'x_label', 'y_label', 'x_cols' and 'y_cols' are still
        translated to the new names (with a warning).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which has been plotted.
    ax : matplotlib.axes.Axes or tuple of matplotlib.axes.Axes
        The axes (or tuple of axes, if secondary axis is used) on which has been plotted.

    Raises
    ------
    ValueError
        If `kind` is unknown, if xcols / ycols have a type that is not allowed for `kind`, or
        if a deprecated keyword and its replacement are both passed in kwargs.
    """
    # Translate the old names for y_label, x_label, ... to the new ones
    name_updates = {'y_label':'ylabel', 'x_label':'xlabel', 'y_cols':'ycols', 'x_cols':'xcols'}
    for old_name in name_updates:
        for key in kwargs.copy():
            if old_name in key: # check if the key contains the old name
                new_key = key.replace(old_name, name_updates[old_name])
                if new_key not in kwargs:
                    warnings.warn(f"'{key}' is deprecated. Use '{new_key}' instead.")
                    if new_key == 'ylabel':
                        ylabel = kwargs.pop(key)
                    elif new_key == 'xlabel':
                        xlabel = kwargs.pop(key)
                    elif new_key == 'ycols':
                        ycols = kwargs.pop(key)
                    elif new_key == 'xcols':
                        xcols = kwargs.pop(key)
                    else:
                        # prefixed variants (e.g. secondary-axis settings) stay in kwargs
                        kwargs[new_key] = kwargs.pop(key)
                else:
                    raise ValueError(f"'{key}' and '{new_key}' are both provided. "
                                     f"Please provide only '{new_key}'.")

    # check if "clip_figure" is in kwargs and process it (only used to suppress clipping when
    # plotting subplots with plot_subplots())
    clip_figure = kwargs.pop("clip_figure", True)

    # check plot kind
    if kind not in self.plot_kinds_custom + self.plot_kinds_pd:
        raise ValueError(f"Invalid plot kind '{kind}'. Must be one of {self.plot_kinds_pd} " +
                         f"(plot kinds from pandas) or {self.plot_kinds_custom}")
    kind_to_plot = kind

    # apply kind_specific settings (from Class attribute kind_specifics)
    if kind in Plotter.kind_specifics:
        # Settings that are defaulted to something
        if 'defaults' in Plotter.kind_specifics[kind]:
            for key, value in Plotter.kind_specifics[kind]['defaults'].items():
                if key not in kwargs:
                    kwargs[key] = value
        # change kind_to_plot for custom kinds
        kind_to_plot = Plotter.kind_specifics[kind].get('kind_to_plot', kind)
        # check xcols type
        if 'xcols_check' in Plotter.kind_specifics[kind]:
            if not isinstance(xcols, Plotter.kind_specifics[kind]['xcols_check']['type']):
                raise ValueError(Plotter.kind_specifics[kind]['xcols_check']['message'])
        # check ycols type
        if 'ycols_check' in Plotter.kind_specifics[kind]:
            if not isinstance(ycols, Plotter.kind_specifics[kind]['ycols_check']['type']):
                raise ValueError(Plotter.kind_specifics[kind]['ycols_check']['message'])

    # A valid kind may have no entry in kind_specifics (the block above already allows for
    # that), so the later lookups use .get() instead of indexing.
    uses_cycler = Plotter.kind_specifics.get(kind_to_plot, {}).get('uses_cycler', False)
    plot_func = Plotter.kind_specifics.get(kind, {}).get('plot_func')

    ycols = self._cols_check_n_correct(df, ycols)
    ylabel = self._infer_label(ycols, ylabel, axis_context='the y-axis')

    # keep provided xlabel (needed for check of xlabel with sec axis)
    xlabel_provided = xlabel

    # handling xcols
    if xcols is not None:
        xcols = self._cols_check_n_correct(df, xcols, required_lengths=[1, len(ycols)])
        xlabel = self._infer_label(xcols, xlabel)
        if isinstance(xcols, (list, tuple)) and len(xcols) == 1:
            xcols = xcols[0]
    else:
        # xcols were not passed --> the index of the dataframe will be used as x-values
        # Nothing to be done except getting the name
        xlabel = self._infer_label([df.index.name] if df.index.name is not None else ['x'],
                                   xlabel, axis_context='the x-axis')

    # TODO (Idea): Infer colors for the lines from the physical_quantities DataFrame

    # Check if secondary y-axis is used and extract the settings for it
    sec_settings = self._get_secondary_ax_settings(kwargs, required_kwargs='ycols')
    if sec_settings:
        sec_ycols = self._cols_check_n_correct(df, sec_settings.pop('ycols'))
        sec_ylabel = self._infer_label(sec_ycols, sec_settings.pop('ylabel', None),
                                       axis_context='the secondary y-axis')
        # Handle xcols for sec axis
        if 'xcols' in sec_settings:
            sec_xcols = sec_settings.pop('xcols')
        else: # Default will be the same as for primary axis
            sec_xcols = xcols
        if sec_xcols is not None:
            sec_xcols = self._cols_check_n_correct(df, sec_xcols,
                                                   required_lengths=[1, len(sec_ycols)])
            sec_xlabel = sec_settings.pop('xlabel', xlabel_provided)
            if sec_xlabel is None:
                # This means there was no xlabel provided for either primary or sec axis
                # Only then infer a label
                sec_xlabel = self._infer_label(sec_xcols, None,
                                               axis_context='the secondary x-axis')
            if len(sec_xcols)==1:
                sec_xcols = sec_xcols[0]
        else: # neither xcols nor sec_xcols were passed --> df.index will be used as x-values
            # check for sec_xlabel in sec_settings, otherwise use the primary xlabel
            sec_xlabel = sec_settings.pop('xlabel', xlabel)
        # Warn if sec_xlabel is inconsistent
        if sec_xlabel != xlabel:
            warnings.warn((f"xlabel '{xlabel}' is inconsistent with sec_xlabel " +
                           f"'{sec_xlabel}'. Setting xlabel = None."))
            xlabel = None

    # take all kwargs that are in the form of rcParams and parse them to the correct form for
    # rcParams
    update_rcparams, parsed_rcparams = self._get_rcparams_from_kwargs(kwargs)

    # handle plot style
    style = self._style_check_n_correct(style)

    # Take title from kwargs and call _infer_title
    title = self._infer_title(kind.capitalize() + ' plot', ycols, style, kwargs.pop('title', 0))

    # all steps that use rcParams (esp. all the plotting) should be done in this style context
    with plt.style.context(style):
        # override style with the rcParams settings provided in kwargs
        with plt.rc_context(update_rcparams):
            self._last_context = plt.rcParams.copy()
            # kwargs that were consumed as rcParams must not reach the plot call again
            specific_kwargs = {key: value for key, value in kwargs.items()
                               if key not in update_rcparams and key not in parsed_rcparams}
            setup_settings, plot_settings, other_kwargs = self._get_setup_and_plot_settings(
                specific_kwargs)
            legend_settings = _extract_kwargs_with_prefix(other_kwargs, ['legend', 'leg'])
            label_settings = _extract_kwargs_with_prefix(other_kwargs, ['label', 'lab'])
            xlabel_settings = _extract_kwargs_with_prefix(other_kwargs, ['xlabel'])
            ylabel_settings = _extract_kwargs_with_prefix(other_kwargs, ['ylabel'])

            # setup figure and axes
            fig, ax = self._setup_plot(title=title, fig=fig, ax=ax, **setup_settings)

            # Create and set the cycler for cycable properties (e.g. color, linestyle)
            if uses_cycler:
                cycle_dict = plt.rcParams['axes.prop_cycle'].by_key()
                cycler, cycle_dict, other_kwargs = cyclers.create_cycle(cycle_dict,
                                                                        **other_kwargs)
                ax.set_prop_cycle(cycler)

            # actual plotting
            if len(ycols)==1:
                # if only one y_col is provided, it is passed as a string
                ycols = ycols[0]
            if plot_func is not None:
                plot_func(df, xcols=xcols, ycols=ycols, ax=ax,
                          **plot_settings, **other_kwargs)
            else:
                df.plot(x=xcols, y=ycols, kind=kind_to_plot, ax=ax,
                        **plot_settings, **other_kwargs)

            # set labels
            self._setup_labels(ax=ax, xlabel=xlabel, ylabel=ylabel,
                               xlabel_settings=xlabel_settings,
                               ylabel_settings=ylabel_settings, **label_settings)

            # Plot the secondary y-axis if settings are provided
            if sec_settings:
                # Assume that all settings passed for the 'normal' axis are also valid for the
                # secondary axis --> copy them and update with sec_settings
                update_rcparams, parsed_rcparams = self._get_rcparams_from_kwargs(sec_settings)
                with plt.rc_context(update_rcparams):
                    sec_setup_settings, sec_plot_settings, sec_other_kwargs = (
                        self._get_setup_and_plot_settings(sec_settings))
                    sec_setup_settings = {**setup_settings, **sec_setup_settings}
                    sec_plot_settings = {**plot_settings, **sec_plot_settings}
                    sec_other_kwargs = {**other_kwargs, **sec_other_kwargs}
                    # Remove settings, that should not be repeated for the sec axis
                    # (legend, label)
                    _extract_kwargs_with_prefix(sec_other_kwargs, ['legend', 'leg'])
                    sec_label_settings = _extract_kwargs_with_prefix(sec_other_kwargs,
                                                                     ['label', 'lab'])
                    sec_xlabel_settings = _extract_kwargs_with_prefix(sec_other_kwargs,
                                                                      ['xlabel'])
                    sec_ylabel_settings = _extract_kwargs_with_prefix(sec_other_kwargs,
                                                                      ['ylabel'])

                    # setup figure and axes
                    ax = (ax, ax.twinx()) # create a secondary y-axis
                    self._setup_plot(title=None, fig=fig, ax=ax[1], **sec_setup_settings)

                    # Create and set the prop_cycler
                    if uses_cycler:
                        # number of lines already drawn on the primary axis
                        n_lines = len(ycols) if not isinstance(ycols, str) else 1
                        # NOTE(review): the kwargs returned here are bound to other_kwargs,
                        # which is not read again below, while the primary call feeds the
                        # returned kwargs into the plot call. Confirm whether
                        # sec_other_kwargs was meant to be updated instead.
                        cycler, cycle_dict, other_kwargs = cyclers.create_cycle(
                            cycle_dict, n_lines=n_lines, **sec_other_kwargs)
                        ax[1].set_prop_cycle(cycler)

                    # actual plotting on the secondary y-axis
                    if plot_func is not None:
                        plot_func(df, xcols=sec_xcols, ycols=sec_ycols, ax=ax[1],
                                  **sec_plot_settings, **sec_other_kwargs)
                    else:
                        df.plot(x=sec_xcols, y=sec_ycols, kind=kind_to_plot, ax=ax[1],
                                **sec_plot_settings, **sec_other_kwargs)

                    # setup the labels for the secondary axis
                    self._setup_labels(ax=ax[1], xlabel=sec_xlabel, ylabel=sec_ylabel,
                                       is_sec_axis=True, xlabel_settings=sec_xlabel_settings,
                                       ylabel_settings=sec_ylabel_settings,
                                       **dict(label_settings, **sec_label_settings))

            # Post-Processing section
            if plot_settings.get('legend', True):
                create_legend(ax, **legend_settings)

            # clip figure if no explicit figsize is given and if ax has a fixed aspect ratio
            if clip_figure and 'figsize' not in kwargs:
                # handle primary and potential secondary axis
                if isinstance(ax, plt.Axes):
                    has_fixed_aspect = ax.get_aspect() != 'auto'
                else:
                    has_fixed_aspect = any(a.get_aspect() != 'auto' for a in ax)
                if has_fixed_aspect:
                    _clip_fig(fig)
    return fig, ax
def _plot_multi_scatter(self,
                        df: pd.DataFrame,
                        xcols,
                        ycols,
                        ax,
                        **kwargs
                        ):
    """
    Draw several x/y column pairs as marker-only series on one axes.

    plot() dispatches to this function for kind='multi_scatter'. Every y column is paired with
    an x column; a single x column is reused for all y columns.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data to plot.
    xcols : str or list of str
        Name of the column(s) to use for the x-axis.
    ycols : str or list of str
        Name of the column(s) to use for the y-axis.
    ax : matplotlib.axes.Axes
        The axes on which to plot.
    **kwargs : dict
        Additional keyword arguments forwarded to DataFrame.plot (e.g., linestyle, linewidth).
    """
    # Normalise both arguments to sequences so they can be paired up below
    x_names = [xcols] if isinstance(xcols, str) else xcols
    y_names = [ycols] if isinstance(ycols, str) else ycols

    # A single x column serves every y column
    if len(x_names) == 1:
        x_names = x_names * len(y_names)

    # One series per (x, y) pair, indexed by its x values; joined column-wise so that a single
    # line plot over the combined index draws all pairs.
    series_per_pair = []
    for x_name, y_name in zip(x_names, y_names):
        series_per_pair.append(df.set_index(x_name)[y_name])
    combined = pd.concat(series_per_pair, axis=1)

    combined.plot(y=y_names, kind='line', ax=ax, **kwargs)
def _plot_R2_scatter(self,
                     df: pd.DataFrame,
                     xcols: str,
                     ycols: str,
                     ax: plt.Axes,
                     metrics: Optional[Union[str, List[str]]] = 'all',
                     **kwargs):
    """
    Draw an R2_scatter plot for comparing modelled and measured values.

    The axes get an equal aspect ratio, an angle bisector line is added, the requested metrics
    are written into the plot area and finally the data is drawn as a scatter plot.
    plot() dispatches to this function for kind='R2_scatter'.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data to plot.
    xcols : str
        Name of the column to use for the x-axis (modelled values).
    ycols : str
        Name of the column to use for the y-axis (measured values).
    ax : matplotlib.axes.Axes
        The axis on which to plot.
    metrics : str or list of str, optional
        Metrics to display on the plot. Can be 'all', 'R2', 'MAE', 'RMSE' or a list of these.
        To suppress the metrics, set to None or False.
    **kwargs : dict
        Additional keyword arguments forwarded to DataFrame.plot (e.g., marker style).
    """
    # equal scaling on both axes gives the plot area a square shape
    ax.set_aspect('equal', adjustable='box')

    x_values = df.loc[:, xcols]
    y_values = df.loc[:, ycols]

    # decorations first, data last
    self._add_angle_bisector_line(x=x_values, y=y_values, ax=ax)
    show_metrics = bool(metrics)
    if show_metrics:
        self._add_metrics_to_R2_plot(x=x_values, y=y_values, ax=ax, metrics=metrics)

    df.plot(x=xcols, y=ycols, kind='scatter', ax=ax, **kwargs)
def _plot_R2_hexbin(self,
                    df: pd.DataFrame,
                    xcols: str,
                    ycols: str,
                    ax: plt.Axes,
                    metrics: Optional[Union[str, List[str]]] = 'all',
                    **kwargs):
    """
    Draw an R2_hexbin plot for comparing modelled and measured values.

    The axes get an equal aspect ratio, an angle bisector line is added, the requested metrics
    are written into the plot area and finally the data is drawn as a hexbin plot.
    plot() dispatches to this function for kind='R2_hexbin'.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data to plot.
    xcols : str
        Name of the column to use for the x-axis (modelled values).
    ycols : str
        Name of the column to use for the y-axis (measured values).
    ax : matplotlib.axes.Axes
        The axis on which to plot.
    metrics : str or list of str, optional
        Metrics to display on the plot. Can be 'all', 'R2', 'MAE', 'RMSE' or a list of these.
        To suppress the metrics, set to None or False.
    **kwargs : dict
        Additional keyword arguments forwarded to DataFrame.plot (e.g., gridsize).
    """
    # equal scaling on both axes gives the plot area a square shape
    ax.set_aspect('equal', adjustable='box')

    x_values = df.loc[:, xcols]
    y_values = df.loc[:, ycols]

    # decorations first, data last
    self._add_angle_bisector_line(x=x_values, y=y_values, ax=ax)
    show_metrics = bool(metrics)
    if show_metrics:
        self._add_metrics_to_R2_plot(x=x_values, y=y_values, ax=ax, metrics=metrics)

    df.plot(x=xcols, y=ycols, kind='hexbin', ax=ax, **kwargs)
def _add_angle_bisector_line(self, x, y, ax):
"""
Add an angle bisector line to the plot.
Parameters
----------
x : array-like
x-values of the data.
y : array-like
y-values of the data.
ax : matplotlib.axes.Axes
The axis on which to plot.
"""
min_val = min(x.min(), y.min())
max_val = max(x.max(), y.max())
ax.plot([min_val, max_val], [min_val, max_val], color='black', linestyle='--',
linewidth=1, label='Angle bisector')
def _add_metrics_to_R2_plot(self, x, y, ax, metrics = 'all'):
    """
    Write the requested metrics (R2, MAE, RMSE) into the upper left corner of the axes.

    The metrics are computed with scikit-learn. If scikit-learn cannot be imported, a warning
    is issued and nothing is drawn.

    Parameters
    ----------
    x : array-like
        x-values of the data.
    y : array-like
        y-values of the data.
    ax : matplotlib.axes.Axes
        The axis on which to plot.
    metrics : str or list of str, optional
        Metrics to display on the plot. Can be 'all', 'R2', 'MAE', 'RMSE' or a list of these.
        True is treated like 'all'. Default is 'all'.

    Raises
    ------
    ValueError
        If a requested metric is not one of 'R2', 'MAE', 'RMSE'.
    """
    try:
        from sklearn.metrics import r2_score, mean_absolute_error
    except ImportError:
        warnings.warn("scikit-learn is required to print metrics on R2 plots. "+
                      "Please install it via 'pip install scikit-learn' or 'conda install "+
                      "scikit-learn'. Metrics are not printed now.")
        return
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        # root_mean_squared_error only exists in scikit-learn >= 1.4. Older installations
        # must not be reported as "scikit-learn missing", so the RMSE is derived from the MSE.
        from sklearn.metrics import mean_squared_error

        def root_mean_squared_error(y_true, y_pred):
            return sqrt(mean_squared_error(y_true, y_pred))

    # NOTE(review): r2_score is not symmetric and its signature is (y_true, y_pred), so x is
    # used as the reference here. The R2 plot functions document x as modelled and y as
    # measured values -- confirm that this argument order is intended.
    metric_funcs = {
        'R2': r2_score,
        'MAE': mean_absolute_error,
        'RMSE': root_mean_squared_error,
    }

    # Callers only test `metrics` for truthiness, so True can arrive here as well
    if metrics is True or (isinstance(metrics, str) and metrics == 'all'):
        metrics = list(metric_funcs)
    elif isinstance(metrics, str):
        metrics = [metrics]

    text_lines = []
    for metric in metrics:
        if metric not in metric_funcs:
            raise ValueError(f"Invalid metric '{metric}'. Must be one of {metric_funcs.keys()}")
        # only the requested metrics are computed
        metric_value = metric_funcs[metric](x, y)
        text_lines.append(f"{metric}: {metric_value:.2f}\n")
        if metric == 'R2' and metric_value < 0:
            message = (
                f"R² = {metric_value:.2f} < 0. This indicates that the model is worse than a "
                "horizontal line. For more information check the documentation of the r2_score "
                "function at: "
                "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html"
            )
            warnings.warn(message)
    metrics_text = ''.join(text_lines)

    # add metrics to the plot (same font size as the x tick labels, anchored in axes coordinates)
    fontsize = plt.rcParams.get('xtick.labelsize')
    ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes, fontsize=fontsize, va='top')
[docs]
def plot_subplots(self,
                  df: pd.DataFrame,
                  plots: dict,
                  layout: Optional[tuple] = None,
                  style: Optional[str] = None,
                  figsize: Optional[tuple] = None,
                  hspace: Optional[float] = 0.05,
                  wspace: Optional[float] = 0.05,
                  **kwargs: Optional[dict]
                  ) -> tuple[plt.Figure, np.ndarray]:
    """
    Plot multiple subplots in a single figure.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data to plot.
    plots : dict
        Dictionary where keys are subplot titles (showing the title can be suppressed for each
        subplot with "title = None" within respective entry in the plots dict) and values are
        dicts containing keyword arguments for the plot (including plot kind). More information
        on available plot kinds and valid arguments can be obtained from the plot() method.
        The passed dictionary is not modified.
    layout : tuple, optional
        Tuple containing the layout of the subplots (number of rows, number of columns).
        If not provided, the layout will be inferred from the number of plots.
    style : str, optional
        Style to use for the plot. If not provided, the default style of the Plotter instance
        will be used. A 'style' entry inside a subplot's dict takes precedence for that
        subplot.
    figsize : tuple, optional
        Tuple containing the size of the figure (width, height) in inches.
        If not provided, the default size from the used style will be used.
    hspace : float, optional
        Height space between the subplots as a fraction of size of the subplot group as a whole.
        If there are more than two rows, the hspace is shared between them. Default is 0.05.
        More information on padding:
        https://matplotlib.org/stable/users/explain/axes/constrainedlayout_guide.html#padding-and-spacing
    wspace : float, optional
        Width space between the subplots as a fraction of size of the subplot group as a whole.
        If there are more than two columns, the wspace is shared between them. Default is 0.05.
        More information on padding:
        https://matplotlib.org/stable/users/explain/axes/constrainedlayout_guide.html#padding-and-spacing
    **kwargs : optional
        Additional keyword arguments to pass to the subplot creation of
        matplotlib.pyplot.subplots(). For more information see:
        https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.subplots.html

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the subplots.
    axs : numpy.ndarray of matplotlib.axes.Axes
        Array of axes objects representing the subplots.

    Raises
    ------
    ValueError
        If the layout has fewer cells than there are plots.
    """
    #TODO: It should be possible to provide a plot type once for all subplots
    #TODO: Allow to set rcParams other than via style for the entire figure
    plots_dict = deepcopy(plots) # original plots dict shall not be changed by this method

    # infer layout from number of plots if layout not provided explicitly
    if layout is None:
        num_plots = len(plots_dict)
        layout = (num_plots, 1) if num_plots <= 3 else (2, num_plots // 2 + num_plots % 2)

    # check if layout makes sense
    if layout[0]*layout[1] < len(plots_dict):
        raise ValueError(f"Layout {layout} is too small for the number of plots {len(plots_dict)}.")
    if layout[0]*layout[1] > len(plots_dict):
        warnings.warn(f"Layout {layout} is too big for the number of plots {len(plots_dict)}.")

    # handle plot style; the style as requested is kept so it can be forwarded to plot()
    requested_style = style
    style = self._style_check_n_correct(style)

    # all steps that work on the plot should be done in this style context
    with plt.style.context(style):
        # setup figure and axes
        fig, axs = plt.subplots(layout[0], layout[1], **kwargs)
        # plt.subplots returns a bare Axes for a 1x1 layout (squeeze); wrap it so that the
        # iteration below and the documented return type also hold for a single subplot
        axs = np.atleast_1d(axs)

        # set figsize if given explicitly
        if figsize:
            fig.set_size_inches(figsize)

        # iterate over the plots to create the subplots
        for ax, (title, plot_kwargs) in zip(axs.flatten(), plots_dict.items()):
            title = plot_kwargs.pop('title', title)
            # plot() opens its own style context; without forwarding the requested style it
            # would re-apply the instance default on top of the style chosen here
            plot_kwargs.setdefault('style', requested_style)
            self.plot(df=df, fig=fig, ax=ax, clip_figure=False, title=title, **plot_kwargs)

        # Adjust the layout to ensure appropriate spacing between the subplots
        layout_engine = fig.get_layout_engine()
        if layout_engine is None:
            # styles without a layout engine cannot take the padding settings
            warnings.warn("The figure has no layout engine; hspace and wspace are ignored.")
        else:
            layout_engine.set(hspace=hspace, wspace=wspace)

        # clip figure if no explicit figsize is given and if ax has a fixed aspect
        if figsize is None:
            has_fixed_aspect = any(a.get_aspect() != 'auto' for a in axs.flatten())
            if has_fixed_aspect:
                _clip_fig(fig)
    return fig, axs
### PRIVATE METHODS ###
def _get_rcparams_from_kwargs(self, kwargs):
"""
Extract rcParams from kwargs and return them in the correct form for rcParams.
For convenience, aliases are provided for some parameters. Currently available:
- fontsize: font.size (Note: fontsize sets the 'medium' fontsize, not all)
- grid: axes.grid (Note: grid is a boolean, not a string)
Parameters
----------
kwargs : dict
Dictionary containing keyword arguments.
Returns
-------
dict
Dictionary containing rcParams settings.
list
List of rcParams settings that were parsed from kwargs.
"""
# parameter names that shall be known in the plotter and be parsed to rcParams
parse_rcparams = {'fontsize' : 'font.size',
'grid' : 'axes.grid'}
# extract rcParams from kwargs and set them in the correct form for rcParams
update_rcparams = {value: kwargs[key] for key, value in parse_rcparams.items()
if key in kwargs and key in parse_rcparams}
parsed_rcparams = [key for key in parse_rcparams if key in kwargs]
# extract all other rcParams from kwargs, that are already in the form of rcParams
update_rcparams.update({key: value for key, value in kwargs.items()
if key in plt.rcParams})
return update_rcparams, parsed_rcparams
def _style_check_n_correct(self, style: Optional[str]) -> str:
"""
Handle the style for the plot.
Parameters
----------
style : str, optional
Style(s) to use for the plot(s). If None, the default style will be used.
Returns
-------
str
The style to use for the plot.
"""
# check if style exists and set it
if style is None:
style = self.style_default
elif style in self.sat_styles:
path = 'sattoolbox.plots.plot_styles.'
style = [path+'sat_base', path+style]#'sattoolbox.plots.plot_styles.'+style
elif style in plt.style.available:
pass
else:
if not style.endswith('.mplstyle'):
style += '.mplstyle'
try:
with plt.style.context(style):
pass
except OSError as exc:
raise ValueError(f"style must be one of {self.sat_styles}, or any of "+
f"plt.style.available. You provided '{style}'. Custom style sheets"+
" must be in the working dir or the full path must be passed."
) from exc
return style
def _get_setup_and_plot_settings(self,
kwargs):
"""
Split the kwargs into setup and plot settings.
Parameters
----------
kwargs : dict
Dictionary containing keyword arguments.
Returns
-------
setup_settings : dict
Dictionary containing keyword arguments for setting up the plot.
plot_settings : dict
Dictionary containing keyword arguments for plotting the data.
other_kwargs : dict
Dictionary containing keyword arguments that are not used for setting up the plot or
plotting the data.
"""
# only the following arguments will be returned as setup settings or plot settings
setup_settings_names = ['figsize', 'layout']
plot_settings_names = ['legend']
# get all settings: provided settings extend or overwrite default settings
all_settings = self.default_plot_settings.copy()
all_settings.update(kwargs)
# some rcParams are ignored by pandas. Thus, they need to be extracted from plt.rcParams
# and will be set explicitely
rcparams_ignored_by_pandas = {} # key = plot setting, value = rcParam
for key, value in rcparams_ignored_by_pandas.items():
if key not in all_settings:
all_settings[key] = plt.rcParams[value]
setup_settings = {key: all_settings[key] for key in setup_settings_names
if key in all_settings}
plot_settings = {key: all_settings[key] for key in plot_settings_names
if key in all_settings}
other_kwargs = {key: value for key, value in all_settings.items()
if key not in setup_settings.keys() and key not in plot_settings.keys()}
return setup_settings, plot_settings, other_kwargs
def _get_secondary_ax_settings(self,
kwargs: dict,
required_kwargs: Optional[List] = None) -> dict:
"""
Find the settings for the secondary y-axis in the kwargs. All settings that start with
'sec_' or 'secondary_' are considered settings for the secondary y-axis. They are removed
from the passed kwargs dictionary and returned as a separate dicitonary where the prefix is
removed from the keys. If required kwargs are not provided, the secondary axis will not be
created and a warning will be raised.
Parameters
----------
kwargs : dict
Dictionary containing keyword arguments.
required_kwargs : list, optional
List of required keyword arguments for the secondary y-axis. This might be dependent on
the calling method. A Timeseries plot might require 'sec_ycols', while a Scatter plot
might also require 'sec_xcols'.
Returns
-------
dict
Dictionary containing the settings for the secondary y-axis.
"""
#Settings that are defaulted to something else for the secondary axis than for the primary
sec_default_settings = {'grid': False}
if isinstance(required_kwargs, str):
required_kwargs = [required_kwargs]
# A secondary y-axis is used if 'sec_ycols' is provided in kwargs
# kwargs that start with 'sec_' or 'secondary_' are considered settings for the sec y-axis
# They are removed from the original kwargs and returned as sec_settings
# The prefix 'sec_' or 'secondary_' is removed from the keys
sec_settings = {key.split('_', maxsplit=1)[1]: kwargs.pop(key)
for key in kwargs.copy().keys() if key.startswith(('sec_','secondary_'))}
if len(sec_settings)>0:
missing_keys = [key for key in required_kwargs if key not in sec_settings]
# set default settings for the sec axis (only if they are not explicitely provided)
for key, value in sec_default_settings.items():
if key not in sec_settings:
sec_settings[key] = value
if len(missing_keys)>0:
warnings.warn('You provided settings for a secondary y-axis, but did not provide '+
f'the following required kwargs: {required_kwargs}. '+
"The secondary y-axis will not be created.")
sec_settings = {} # reset sec_settings if required ones were not provided
return sec_settings
def _set_physical_quantities(self, physical_quantities: Optional[pd.DataFrame] = None):
"""
Set the physical quantities that shall be known to the Plotter.
This is used for inferring labels for the axes.
Parameters
----------
physical_quantities : pd.DataFrame, optional
DataFrame containing information about physical quantities.
The DataFrame must have columns: 'label', 'color', 'symbol', 'unit'.
If not provided, a default DataFrame will be used.
Raises
------
ValueError
If the physical_quantities DataFrame does not have the required columns.
Returns
-------
None
"""
#TODO: it could be nicer to load this from a file
if physical_quantities is None:
# default physical_quantities DataFrame:
# label = Name + symobl + "in" + unit
# color = default color plots, should be unique for each quantity
# symbol = physical symbol for the quantity
# unit = unit of the quantity (SI or other standard unit)
self.physical_quantities = pd.DataFrame(data={
'T' : ['Temperature T in °C', 'red', r'$T_{SL}$', '°C'],
'dT' : [r'Temperature difference $\Delta$T in K', 'red', r'$\DeltaT$', 'K'],
'p' : ['Pressure p in bar', 'grey', '$p$', 'bar'],
'dp' : [r'Differential pressure $\Delta$p [bar]', 'lightgrey', r'$\Delta$p', 'bar'],
'm_flow' : [r'Mass flow $\dot{m}$ in kg/s', 'green', r'$\dot{m}$', 'kg/s'],
'Q_flow' : [r'Heat flow $\dot{Q}$ in W', 'orange', r'$\dot{Q}$', 'W'],
},
index=['label', 'color', 'symbol', 'unit'])
else:
# check if physical_quantities has the required rows; raise value error otherwise
required_rows = ['label', 'color', 'symbol', 'unit']
if not isinstance(physical_quantities, pd.DataFrame):
raise TypeError("physical_quantities must be a pandas DataFrame.")
if not set(required_rows).issubset(physical_quantities.index):
raise ValueError("physical_quantities DataFrame must have rows: 'label', "+
"'color', 'symbol', 'unit'")
self.physical_quantities = physical_quantities
def _set_default_plot_settings(self, plot_settings: Optional[dict] = None):
"""
Set default plot settings that differ from matplotlib defaults.
Parameters
----------
plot_settings : dict, optional
Dictionary containing default plot settings.
Whatever is not provided, will be taken from a default dictionary.
Returns
-------
None
"""
# Add any default settings
# Important: Only settings, that are not part of matplotlib.rcParams, go here.
# Any rcParams should be set through styles.
default_plot_settings = {
# 'lab_hor_ylabels': False,
}
# create dummy plot_settings if not provided
if plot_settings is None:
plot_settings = {}
# warn if additional settings are provided (they might be ignored by the plot methods)
additional_settings = [key for key in plot_settings if key not in default_plot_settings]
if len(additional_settings) > 0:
warnings.warn("You are introducing the following additional settings to the "+
"'default_plot_settings' (they might be ignored by the plot methods): "+
str(additional_settings))
# integrate specific settings into the default settings
default_plot_settings.update(plot_settings)
# set the default settings of the Plotter object
self.default_plot_settings = default_plot_settings
def _cols_check_n_correct(self,
df : pd.DataFrame,
cols: Optional[Union[str, List[str]]] = None,
required_lengths: Optional[List[int]] = None) -> list:
"""
Return a proper cols list.
- If cols is None, all columns from the DataFrame will be used.
- If cols is a string, it will be wrapped in a list containing the string.
- If cols is a list, it will be left as is.
- The function checks if all the names in cols exist in the DataFrame.
- If any of the columns in cols do not exist in the DataFrame, a warning will be given.
- If cols is empty after removing non-existing columns, a ValueError will be raised.
- If required_lengths is provided, the function will check if the length of cols is in the
required length. If not, a ValueError will be raised.
Parameters
----------
df : pd.DataFrame
The DataFrame to check the columns against.
cols : str or list of str, optional
The columns to be checked and corrected.
required_lengths : list of int, optional
The list of acceptable length of cols. If not provided, no check will be done.
Returns
-------
list of str
The corrected cols list.
"""
if isinstance(cols, str):
cols = [cols]
if cols is None:
cols = list(df.columns)
if not all(col in df.columns for col in cols):
missing_cols = [col for col in cols if col not in df.columns]
warnings.warn("Some of the columns in cols do not exist in the DataFrame:\n"+
f"{missing_cols} \n"+
"Those columns will be removed from cols.")
cols = [col for col in cols if col in df.columns]
if cols == []:
raise ValueError("cols is empty after removing non-existing columns. This means\n"+
"that either the DataFrame is empty or you passed only cols that\n"+
"do not exist in the DataFrame.")
if required_lengths:
if len(cols) not in required_lengths:
raise ValueError(f"cols must have a length of {required_lengths}. You provided "+
f"cols {cols} with a length of {len(cols)}.")
return cols
def _infer_label(self,
cols: list,
label: Optional[str] = None,
axis_context: Optional[str] = 'this axis'
) -> str:
"""
Try to infer a proper label, if it is None from the variables names in cols.
Parameters
----------
cols : list
List of column names that are plotted on the axis.
label : str, optional
Label for the axis. Default is None.
Returns
-------
str
The inferred label for the y-axis.
"""
if label is False:
# If the label is False, no label will be used
return None
if label is not None:
# If a label was explcitely provided, it will be used as label
return label
# Begin constructing a warning message
warn_mess = f"You did not pass a label name for {axis_context}."
label_dict = self.physical_quantities.loc['label', :].to_dict()
# find matches between cols and label_dict keys
matches = []
for y_col in cols:
this_match = _find_longest_match(label_dict.keys(), y_col)
if (this_match is not None) and this_match not in matches:
matches += [this_match]
if len(matches) == 1:
# One match is found --> use it as label
inf_label = label_dict[matches[0]]
warn_mess += f" Label: {inf_label} was inferred from the column names."
else:
# No match or multiple matches were found --> return None
inf_label = None
warn_mess += f" Inferring from the column names {cols} was unsuccessful."
if len(matches) == 0:
warn_mess += " (No match)"
if len(cols) == 1:
inf_label = cols[0] # If only one column is given, use it as label
else:
warn_mess += " (Multiple matches)"
warnings.warn(warn_mess)
return inf_label
# TODO: Title inference should also take xcols into account (e.g. for scatter: 'xcol over ycol')
def _infer_title(self,
plot_type: str,
ycols: list,
style: str,
title: Optional[Union[str, bool]] = None
) -> str:
"""
Try to infer a proper title, if it is None or True for the plot from the plot type
Parameters
----------
plot_type : str
A string that is passed from the calling plot function to name the plot type for a title
e.g. 'Timeseries plot', 'Histogram', 'Scatter plot'
ycols : list
List of column names that are plotted on the axis.
Used in the construction of an inferred title.
style : str
The style of the plot. Used to check if a title should be inferred.
title : str, optional
Title for the plot. Default is None. If True or 'infer', the title will be inferred.
If False or None, no title will be used. If string is passed,
it will be used as title.
Any not-string or any bool(str) != True (e.g. '')
will lead to a behavior depending on the style.
Returns
-------
str
The inferred title for the plot. If no title shall be used, None is returned.
"""
# If the title is a string and not 'infer', it will be used as title
if title and isinstance(title, str) and title != 'infer':
return title
# Check if the title shall be inferred
infer_title = (bool(title is not False and title is not None)
and bool(title is True or title == 'infer' or style[1].split('.')[-1] == 'screen'))
if infer_title:
title = plot_type + " of " + ', '.join([str(col) for col in ycols])
else: # Return None if no title shall be inferred
title = None
return title
def _setup_plot(self,
                title: str,
                ax: Optional[plt.Axes] = None,
                fig: Optional[plt.Figure] = None,
                **setup_settings):
    """
    Provide a figure and axes for plotting and set the title on the axes.

    A new figure and axes are created if none are provided. Otherwise the passed objects
    are validated and reused.

    Parameters
    ----------
    title : str
        The title of the plot.
    ax : matplotlib.axes.Axes, optional
        The axes on which to plot. If not provided, a new figure and axes will be created.
    fig : matplotlib.figure.Figure, optional
        The figure on which to plot. If not provided, a new figure and axes will be created.
    setup_settings : dict
        Additional keyword arguments. Only 'figsize' and 'layout' are used; they are
        passed to plt.subplots when a new figure is created.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure object.
    ax : matplotlib.axes.Axes
        The Axes object.

    Raises
    ------
    ValueError
        If only one of ax and fig is provided.
    TypeError
        If ax is not an Axes object or fig is not a Figure object.
    """
    # Pick the settings that are relevant for creating a figure
    fig_settings = {}
    for name in ('figsize', 'layout'):
        if name in setup_settings:
            fig_settings[name] = setup_settings[name]

    ax_given = ax is not None
    fig_given = fig is not None

    if ax_given != fig_given:
        raise ValueError("Both ax and fig must be provided or none of them.")

    if not ax_given:
        # Nothing was passed --> create a fresh figure with a single axes
        fig, ax = plt.subplots(**fig_settings)
    else:
        # Both were passed --> make sure they are of the proper matplotlib types
        if not isinstance(ax, plt.Axes):
            raise TypeError("ax must be a matplotlib.axes.Axes object.")
        if not isinstance(fig, plt.Figure):
            raise TypeError("fig must be a matplotlib.figure.Figure object.")

    ax.set_title(title)
    return fig, ax
def _setup_labels(self,
ax,
xlabel: str,
ylabel: str,
is_sec_axis: bool = False,
xlabel_settings: dict = None,
ylabel_settings: dict = None,
**label_settings):
"""
Set up the labels for the x and y axes.
Parameters
----------
ax : matplotlib.axes.Axes
The axes on which to plot.
xlabel : str
The label for the x-axis.
ylabel : str
The label for the y-axis.
is_sec_axis : bool, optional
Indicates, whether the given ax object is a secondary axis (needed for ylabel position)
xlabel_settings : dict, optional
Additional settings for the x-axis label.
ylabel_settings : dict, optional
Additional settings for the y-axis label.
label_settings : dict
Returns
-------
None
"""
# Create a copy to not change the original settings dict
label_settings_ = label_settings.copy()
# Remove ylabel specific parameters from label_settings
ylabel_exclusives = ['hor_ylabels', 'horizontal_ylabels']
custom_ylabel_settings = {key: label_settings_.pop(key) for key in ylabel_exclusives if
(key in label_settings_ or key.startswith('ylabel'))}
if not is_sec_axis:
# Setting xlabel only for primary axis
# Updating general label settings with xlabel specifics. xlabel specifics take precedence.
xlabel_settings = {**label_settings_, **xlabel_settings}
ax.set_xlabel(xlabel, **xlabel_settings)
else:
ax.set_xlabel(xlabel)
# Setting ylabel
# optional settings for horizontal y labels
# Check for and remove the 'hor_ylabels' and 'horizontal_ylabels' settings
if custom_ylabel_settings.pop('hor_ylabels', False) or custom_ylabel_settings.pop('horizontal_ylabels', False):
ylabel_settings.update({'rotation':'horizontal', 'rotation_mode':"anchor",
'verticalalignment':'baseline', 'ha':'left'})
ax.yaxis.set_label_coords(-0.05, 1.02)
if is_sec_axis:
ax.yaxis.set_label_coords(1.05, 1.02)
ylabel_settings.update({'ha':'right'})
ylabel_settings = {**label_settings_, **ylabel_settings}
ax.set_ylabel(ylabel, **ylabel_settings)
[docs]
def get_last_context(self):
    """
    Return a copy of the style context that was used for the most recent plot.

    Returns
    -------
    dict
        A copy of the last style context used for plotting.

    Raises
    ------
    ValueError
        If no context has been stored yet.
    """
    context = self._last_context
    if context is not None:
        # Hand out a copy so that callers cannot alter the stored context
        return context.copy()
    raise ValueError("No last context exists. Last context available after a plot call.")
[docs]
def create_legend(ax, loc='best', labels=None, **kwargs):
    """
    Creates a (fancy) legend for the plot. This function is a wrapper around the matplotlib
    legend function. It removes the default legend(s) and creates a new one.

    Improvements over the matplotlib legend function are:
    - The location specifier 'outside' can be used to place the legend outside the plot area.
    - The direction of the plot entries is inferred (if not controlled via ncols). E.g.
      spreading the entries horizontally if placed above or below the plot.
    - Legend for two y axes is supported. Entries of both axes are put into the same legend.
      Subtitles for the two axes are supported (can be passed via 'title'-kwarg as tuple of
      two strings). The entries of the two axes are spread over the columns of the legend.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or tuple/list of matplotlib.axes.Axes
        The axes containing the plot. A sequence of two axes (primary, secondary) creates a
        combined legend; a sequence with a single axes is treated like that axes.
    loc : str, optional
        The location of the legend. Default is 'best'.
    labels : list or dict, optional
        List of labels for the legend. If not provided, the labels of the plot will be used.
        If a dictionary is provided, it maps existing labels to their replacements.
    **kwargs : dict
        Additional keyword arguments to pass to the legend.

    Returns
    -------
    leg : matplotlib.legend.Legend
        The legend object.

    Raises
    ------
    NotImplementedError
        If more than two axes are passed.
    TypeError
        If ax is neither an Axes object nor a non-empty sequence of Axes objects.
    """
    # Single axes --> plain legend
    if isinstance(ax, plt.Axes):
        return _create_single_legend(ax, loc, labels, **kwargs)

    # Sequence of axes --> dispatch on the number of axes
    is_axes_sequence = (isinstance(ax, (tuple, list))
                        and len(ax) > 0
                        and all(isinstance(a, plt.Axes) for a in ax))
    if not is_axes_sequence:
        raise TypeError("ax must be a matplotlib.axes.Axes object or a tuple of Axes objects.")

    if len(ax) > 2:
        raise NotImplementedError("Legend creation is only supported for two axes at"+
                                  " the moment. A tuple of more than 2 axes was passed.")
    if len(ax) == 1:
        # The combined legend indexes a second axes; a lone axes gets the plain legend
        return _create_single_legend(ax[0], loc, labels, **kwargs)
    return _create_double_legend(tuple(ax), loc, labels, **kwargs)
def _create_double_legend(axs, loc, labels, **kwargs):
    """
    Create one combined legend for a plot with two y axes.

    Existing legends (on the figure and on both axes) are removed first. The entries of
    both axes are then arranged column-wise in a single legend, each group headed by a
    subtitle. The subtitles are passed as plain strings in the handle list and rendered
    by a custom legend handler.

    Parameters
    ----------
    axs : tuple of matplotlib.axes.Axes
        The two axes containing the plot (index 0 and index 1 are used).
    loc : str
        The location of the legend.
        Possible location specifiers are center, lower, upper, right and left and combinations.
        'best' (Axes only), 'upper right', 'upper left', 'lower left', 'lower right', 'right',
        'center left', 'center right', 'lower center', 'upper center', 'center'
        See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html for more info
        Additionally legend placement outside the plot area is possible by adding 'outside'
        to the location specifier, e.g. 'outside upper right'.
        For the outside locations, the order of left/right and upper/lower can be changed:
        the first specifier signals the side of the plot area and the second one the
        alignment on that respective side.
        See https://matplotlib.org/stable/users/explain/axes/legend_guide.html for more info
    labels : list, dict
        List of labels for the legend (first the entries of axs[0], then those of axs[1]).
        If None, the labels of the plot will be used. If a dictionary is provided, it maps
        existing labels to their replacements.
    **kwargs : dict
        Additional keyword arguments to pass to the legend. 'ncols' (or the deprecated
        'ncol') controls the number of columns. 'title' may be a tuple/list of two strings
        that are used as subtitles for the two groups of entries.

    Returns
    -------
    leg : matplotlib.legend.Legend
        The legend object.
    """
    # Remove legends that were attached to the figure (figure legends)
    leg_fig = [c for c in axs[0].get_figure().get_children()
               if isinstance(c, matplotlib.legend.Legend)]
    for l in leg_fig:
        l.remove()
    # Remove legends that were attached to the axes
    for ax in axs:
        try:
            ax.get_legend().remove()  # get_legend() returns None if there is no legend
        except AttributeError:
            pass

    # Determine the number of columns: explicit value first, otherwise inferred from loc
    if 'ncols' in kwargs:
        ncols = kwargs.pop('ncols')
    elif 'ncol' in kwargs:
        warnings.warn('"ncol" was passed as a parameter in legend. "ncol" is deprecated.'+
                      ' Use "ncols" instead.')
        ncols = kwargs.pop('ncol')
    elif loc.startswith('outside right') or loc.startswith('outside left'):
        ncols = 1  # stack everything vertically beside the plot
    else:
        ncols = 2  # Default for two y axes: one column per axis

    # Collect handles and labels of both axes
    h1, l1 = axs[0].get_legend_handles_labels()
    h2, l2 = axs[1].get_legend_handles_labels()

    # Update labels if explicitly given
    if labels is not None:
        if isinstance(labels, list):
            if len(labels) == len(l1) + len(l2):
                # The first len(l1) labels belong to axs[0], the rest to axs[1]
                # (len(l1) is unchanged by the first assignment)
                l1 = labels[:len(l1)]
                l2 = labels[len(l1):]
            else:
                warnings.warn(f"The number of provided labels ({len(labels)}) does not match the "+
                              f"total number of legend entries ({len(l1) + len(l2)}). "+
                              "The labels will not be updated.")
        elif isinstance(labels, dict):
            # Replace only the labels that appear as keys in the dictionary
            l1 = [labels.get(label, label) for label in l1]
            l2 = [labels.get(label, label) for label in l2]
        else:
            warnings.warn("Labels are not a list. The labels will not be updated.")

    # Divide the legend entries among the columns.
    # cols_1 = number of columns used for the entries of axs[0]
    # n_rows = number of entry rows per column (without the subtitle row)
    cols_1 = 1
    if ncols == 1:
        n_rows = len(h1) + len(h2)
    elif ncols == 2:
        n_rows = max(len(h1), len(h2))
    else:
        # More than 2 columns: increase the columns given to axs[0] as long as the
        # required number of rows does not grow.
        n_rows = 10**6  # A very large number as starting value for the minimum search
        while cols_1 < ncols:
            # Note: // is floor division; negating both sides turns it into ceil division
            n_rows_1 = -(len(h1)//-cols_1)
            n_rows_2 = -(len(h2)//-(ncols-cols_1))
            n_rows_new = max(n_rows_1, n_rows_2)
            if n_rows_new > n_rows:  # Minimum was passed
                break
            n_rows = n_rows_new
            cols_1 += 1
        # Step back to the last column count that was accepted
        cols_1 = cols_1 - 1

    # Get the legend section titles (a two-element 'title' is consumed here; any other
    # 'title' stays in kwargs and is passed on to the legend call)
    if 'title' in kwargs and isinstance(kwargs['title'], (tuple, list)) and len(kwargs['title'])==2:
        title = kwargs.pop('title')
    else:
        title = ('Left axis:', 'Right axis:')

    # Fill the columns with the legend entries.
    # Subtitles are inserted as str "handles" with an empty label.
    if ncols == 1:
        handles = [title[0]] + h1 + [title[1]] + h2
        labels = [''] + l1 + [''] + l2
    else:
        handles = []
        labels = []
        col = 0
        which = 0  # 0 = entries of axs[0], 1 = entries of axs[1]
        while col < ncols:
            if col == cols_1:
                which = 1  # from this column on, the entries of axs[1] are placed
            if which == 0:
                h = h1
                l = l1
            else:
                h = h2
                l = l2
            if col in (0, cols_1):
                # First column of a group: add the subtitle
                handles.append(title[which])
                labels.append('')
            else:
                # Consecutive columns of a group: add an invisible dummy line instead
                handles.append(plt.Line2D([0], [0], color='none'))
                labels.append('')
            added_rows = 0
            while added_rows<n_rows:  # Fill the rows of this column
                if h and l:  # If there are more entries to get
                    # Take the next entry (this consumes the handle and label lists)
                    handles.append(h.pop(0))
                    labels.append(l.pop(0))
                else:
                    # Pad with invisible dummy lines so that all columns have equal length
                    handles.append(plt.Line2D([0], [0], color='none'))
                    labels.append('')
                added_rows += 1
            col += 1

    class LegendTitle(object):
        """
        Legend handler that renders a str handle as a (sub)title inside the legend.

        The title is implemented as an entry of the legend, where the text is shown at the
        handle position.

        Attributes:
        -----------
        text_props : dict, optional
            A dictionary of text properties to customize the appearance of the legend title.
        """
        def __init__(self, text_props=None):
            self.text_props = text_props or {}
            super(LegendTitle, self).__init__()

        def legend_artist(self, legend, orig_handle, fontsize, handlebox):
            """Create the text artist for orig_handle, add it to handlebox and return it."""
            x0, y0 = handlebox.xdescent, handlebox.ydescent
            title = matplotlib.text.Text(x0, y0, orig_handle, **self.text_props)
            handlebox.add_artist(title)
            return title

    # Create the legend
    leg = None
    # Get the fontsize for legend entries (needed for the subtitles to have the same size)
    fontsize = kwargs.get('fontsize', plt.rcParams['legend.fontsize'])
    if loc.startswith('outside'):
        # Turn the frame off, unless explicitly stated otherwise
        kwargs['frameon'] = kwargs.get('frameon', False)
        # Above/below the plot: expand the legend, unless a mode was given
        if 'mode' not in kwargs and loc.split(' ')[1] in ['upper', 'lower']:
            kwargs['mode'] = 'expand'
        if loc.split(' ')[1] in ['left', 'lower', 'right']:
            # in that case a figure legend is necessary to have the legend properly outside
            leg = axs[1].get_figure().legend(handles, labels, ncols=ncols, loc=loc,
                handler_map={str: LegendTitle({'fontsize':fontsize})}, **kwargs)
        else:
            # Otherwise an axes legend is used; _get_legend_loc supplies the loc keywords
            loc = _get_legend_loc(loc, kwargs.pop('bbox_to_anchor', None))
            leg = axs[1].legend(handles, labels, ncols=ncols, **loc,
                handler_map={str: LegendTitle({'fontsize':fontsize})}, **kwargs)
    else:
        leg = axs[1].legend(handles, labels, ncols=ncols, loc=loc,
                handler_map={str: LegendTitle({'fontsize':fontsize})}, **kwargs)
    return leg
def _create_single_legend(ax, loc, labels, **kwargs):
    """
    Create a legend for a plot with a single y axis.

    Existing legends (on the figure and on the axes) are removed before the new legend is
    created.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes containing the plot.
    loc : str
        The location of the legend. Locations starting with 'outside' place the legend
        outside of the plot area.
    labels : list or dict, optional
        List of labels for the legend. If None, the labels of the plot will be used.
        If a dictionary is provided, it maps existing labels to their replacements.
    **kwargs : dict
        Additional keyword arguments to pass to the legend. The deprecated 'ncol' is
        accepted and translated to 'ncols'.

    Returns
    -------
    leg : matplotlib.legend.Legend
        The legend object.
    """
    if isinstance(labels, dict):
        # Replace only the labels that appear as keys in the dictionary
        current_labels = ax.get_legend_handles_labels()[1]
        labels = [labels.get(label, label) for label in current_labels]

    # Translate the deprecated 'ncol' so that an explicit column count is never silently
    # overridden by the inferred 'ncols' below (same handling as for the double legend)
    if 'ncol' in kwargs:
        warnings.warn('"ncol" was passed as a parameter in legend. "ncol" is deprecated.'+
                      ' Use "ncols" instead.')
        kwargs.setdefault('ncols', kwargs.pop('ncol'))

    # Remove legends that were attached to the figure (figure legends)
    figure_legends = [child for child in ax.get_figure().get_children()
                      if isinstance(child, matplotlib.legend.Legend)]
    for figure_legend in figure_legends:
        figure_legend.remove()
    # Remove the legend of the axes, if present
    axes_legend = ax.get_legend()
    if axes_legend is not None:
        axes_legend.remove()

    # Regular placement inside the axes (non-string locs are passed on unchanged)
    if not (isinstance(loc, str) and loc.startswith('outside')):
        return ax.legend(labels=labels, loc=loc, **kwargs)

    # 'outside ...' placement: the part after 'outside ' names the side of the plot
    side = loc[8:]
    if 'ncols' not in kwargs and side.startswith(('lower', 'upper')):
        # Above/below the plot: spread the entries horizontally
        kwargs['ncols'] = len(ax.get_legend_handles_labels()[0])
    # Turn the frame off, unless explicitly stated otherwise
    kwargs.setdefault('frameon', False)

    if side.startswith(('left', 'lower')):
        # in that case a figure legend is necessary to have the legend outside
        return ax.get_figure().legend(labels=labels, loc=loc, **kwargs)

    # Otherwise an axes legend anchored outside of the axes is used
    loc_kwargs = _get_legend_loc(loc, kwargs.pop('bbox_to_anchor', None))
    return ax.legend(labels=labels, **loc_kwargs, **kwargs)
def _get_legend_loc(loc, bbox_to_anchor=None):
"""
Get the location of the legend.
Parameters
----------
loc : str
The location of the legend.
bbox_to_anchor : tuple, optional
The bbox_to_anchor of the legend. Default is None.
Returns
-------
dict
A dict containing the location of the legend.
In 'normal' cases this is a dict with the key 'loc' and the value loc.
In 'outside' cases this is a dict with the key 'loc' and the key 'bbox_to_anchor'
with their specific values.
"""
custom_locs = {'outside upper left': {'loc': 'lower left', 'bbox_to_anchor': (0, 1, 1, 1)},
'outside upper right': {'loc': 'lower right', 'bbox_to_anchor': (0, 1, 1, 1)},
'outside upper center': {'loc': 'lower center', 'bbox_to_anchor': (0, 1, 1, 1)},
'outside lower left': {'loc': 'upper left', 'bbox_to_anchor': (0, 0, 1, 0)},
'outside lower right': {'loc': 'upper right', 'bbox_to_anchor': (0, 0, 1, 0)},
'outside lower center': {'loc': 'upper center', 'bbox_to_anchor': (0, 0, 1, 0)},
'outside right upper': {'loc': 'upper left', 'bbox_to_anchor': (1, 1)},
'outside right lower': {'loc': 'lower left', 'bbox_to_anchor': (1, 0)},
'outside right center': {'loc': 'center left', 'bbox_to_anchor': (1, 0.5)},
'outside left upper': {'loc': 'upper right', 'bbox_to_anchor': (0, 1)},
'outside left lower': {'loc': 'lower right', 'bbox_to_anchor': (0, 0)},
'outside left center': {'loc': 'center right', 'bbox_to_anchor': (0, 0.5)}}
if isinstance(loc, dict):
pass
elif bbox_to_anchor is not None:
loc = {'loc': loc, 'bbox_to_anchor': bbox_to_anchor}
# check if loc is a custom location
elif loc in custom_locs:
loc = custom_locs[loc]
# loc['bbox_transform'] = plt.figure(fig.number).transFigure
else: # Wrap in a dict
loc = {'loc': loc}
return loc
### HELPER FUNCTIONS ###
def _extract_kwargs_with_prefix(kwargs, prefix):
"""
Extract keyword arguments from kwargs that start with a certain prefix.
The keyword arguments are removed from the kwargs. (The kwargs-object is changed)
Parameters
----------
kwargs : dict
Dictionary containing keyword arguments.
prefix : str, list, tuple
The prefix to search for in the keys of the kwargs.
Returns
-------
dict
Dictionary containing the keyword arguments that start with the prefix.
"""
if isinstance(prefix, tuple):
prefix = list(prefix)
elif isinstance(prefix, str):
prefix = [prefix]
i = 0
while i < len(prefix):
if not prefix[i].endswith('_'):
prefix[i] = prefix[i] + '_' # add an underscore at the end
i += 1
prefix = tuple(prefix)
extracted_settings = {key.split('_', maxsplit=1)[1]: kwargs.pop(key)
for key in kwargs.copy().keys() if key.startswith(prefix)}
return extracted_settings
def _find_longest_match(searchlist, match, mode='start'):
"""
Find element from searchlist that has longest match with match:
- the element from searchlist must be totally in match
- the mathing starts either at the start (default) or at the end of the element from
searchlist, this is using .startswith() or .endswith() respectively.
This method is used to infer labels for the axes:
consider: searchlist = ['Q','Q_flow','Q_flow_loss'] and
- match1 = 'Q_flow_loss_xy' => returns 'Q_flow_loss'
- match2 = 'Q_flow_loss' => returns 'Q_flow_loss'
- match3 = 'Q_flow_xy' => returns 'Q_flow'
- match4 = 'Q_xy' => returns 'Q'
Parameters
----------
searchlist : list
list of strings to search longest match in.
match : str
search string.
mode : str
whether to match from 'start' or 'end' of the items in searchlist
Returns
-------
item : str
longest match if any, None otherwise.
"""
searchlist_revsorted = list(searchlist).copy()
searchlist_revsorted.sort(key=len,reverse=True)
match = str(match) # cast match to string
for item in searchlist_revsorted:
if mode == 'start':
if match.startswith(item):
return item
elif mode == 'end':
if match.endswith(item):
return item
else:
raise ValueError("'mode' must be one of 'start' or 'end'. You provided '"+mode+"'.")
return None
def _get_valid_plot_kinds_pandas():
"""
Get the valid plot kinds from the pandas documentation.
This is done by parsing the docstring of the pandas plot function.
The docstring is split into lines and the lines that contain the valid kinds are extracted.
Parameters
----------
None
Returns
-------
valid_kinds : list
List of valid plot kinds.
"""
# Get the docstring of the pandas plot function
docstring_lines = pd.DataFrame().plot.__doc__.split('\n')
docstring_lines = [line.strip() for line in docstring_lines if line.strip() != '']
# find the lines that contain the valid kinds => lines after 'kind : str', starting with "-"
# and before the next line that starts not with "-"
for i, line in enumerate(docstring_lines):
if 'kind : str' in line:
for j in range(i+1, len(docstring_lines)):
if docstring_lines[j].startswith('-'):
start_line = j
break
for j in range(start_line, len(docstring_lines)):
if not docstring_lines[j].startswith('-'):
end_line = j
break
valid_kinds = []
# extract the valid kinds from the lines between start_line and end_line
# The lines start with "-" and contain the kind in the form of "- 'kind'"
for line in docstring_lines[start_line:end_line]:
if line.startswith('-'):
kind = line.split("'")[1]
valid_kinds.append(kind)
return valid_kinds
def _clip_fig(fig):
"""
Clip the figure to its tight bounding box and adjust the size accordingly.
This is useful when the figure has axes with a fixed aspect ratio, to remove white margins.
If no axes have a fixed aspect ratio, no clipping is needed.
"""
# check if any of the axes in the figure have a fixed aspect
is_fixed = any(ax.get_aspect() != 'auto' for ax in fig.axes)
if not is_fixed:
# If no axes have a fixed aspect, no clipping is needed
warnings.warn("No axes with fixed aspect ratio found. No clipping needed. If you want to " \
"change the figure size, you can use the parameter figsize=(width, height) in the call " \
"of plotter.plot().")
return
# Check if colorbars are present and more than one subplot (clipping might not work well then)
has_colorbar = any(hasattr(ax, '_colorbar') for ax in fig.axes)
if has_colorbar and len(fig.axes) > 2:
warnings.warn("Clipping might not work well with colorbars and multiple subplots. Maybe " \
"you must adjust the figsize manually to fit your needs, e.g. by adding " \
"the parameter figsize = (width, height) (in inches) to the call of " \
"plotter.plot().")
# Ensure the figure is drawn so that the actual layout is done to get the correct bbox
fig.canvas.draw()
# Get the tight bounding box
bbox = fig.get_tightbbox(fig.canvas.get_renderer())
width, height = bbox.width, bbox.height
# Get the factor to maintain the figure dimension that must not be trimmed
# (otherwise the resulting figure would be smaller than before)
factor = min(fig.get_size_inches()[0] / width, fig.get_size_inches()[1] / height)
# Do the actual clipping
fig.set_size_inches(width * factor, height * factor)
# draw the figure again to apply the new size
fig.canvas.draw()
### EXAMPLE USAGE ###
[docs]
def example(example_types=None):
    """
    Example usage of the Plotter class.

    Builds a small sample DataFrame (48 hourly values) and calls Plotter.plot() and
    Plotter.plot_subplots() with various argument combinations. Nothing is returned.

    Parameters
    ----------
    example_types : list, optional
        List of plot types to demonstrate. Default is None, which is treated like ['all'].
        Possible values are:
        'all', 'style_examples', 'pd_kinds', 'multi_scatter', 'R2_plots', 'sec_axis', 'subplots',
        'cycler'.
    """
    # None as default avoids a mutable list object being shared between calls
    if example_types is None:
        example_types = ['all']

    def _selected(name):
        """Return True if the example group `name` was requested (directly or via 'all')."""
        return 'all' in example_types or name in example_types

    # Create a sample DataFrame
    data_list = list(range(48))
    data = {'T_1': data_list,
            'T_2': [sqrt(x) for x in data_list],
            'T_3': [sqrt(x)+5 for x in data_list],
            'm_flow': [10+x for x in data_list],}
    hours = pd.date_range(start="2023-01-01", end="2023-01-02 23:00", freq="1h")
    df = pd.DataFrame(data, index=hours)

    # Create instances of Plotter
    plotter = Plotter()
    plotter_presentation = Plotter(style_default='presentation')
    # NOTE(review): plotter_paper is not used below; the instantiation is kept in case
    # the constructor has side effects - confirm and remove if it has none
    plotter_paper = Plotter(style_default='paper')

    # demonstrate different styles, behavior without specifying ycols
    if _selected('style_examples'):
        plotter.plot(df)
        plotter_presentation.plot(df)
        plotter.plot(df, style='presentation')
        plotter.plot(df, style='paper')
        # behavior with specifying title as 'infer'
        plotter.plot(df, style='paper', title='infer')
        # behavior with specifying ycols
        plotter.plot(df, style='paper', ycols=['T_1', 'T_2'])
        # behavior with specifying ycols and ylabel:
        plotter.plot(df, ycols=['T_1', 'T_2'], ylabel='Temperatur', title='Custom Title')
        plotter.plot(df, ycols=['T_1', 'T_4'])  # specify a non-existing column

    # Plot data using the plot method with different kinds of plots
    if _selected('pd_kinds'):
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Line plot', kind='line')
        plotter.plot(df, xcols='T_1', ycols='T_2', title='Bar plot', kind='bar')
        plotter.plot(df, xcols='T_1', ycols='T_2', title='Hor. Bar plot', kind='barh')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Hist plot', kind='hist', alpha=0.5)
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Box plot', kind='box')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='KDE plot', kind='kde')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Density plot', kind='density')
        plotter.plot(df, ycols=['T_1', 'T_2'], title='Area plot', kind='area')
        plotter.plot(df.iloc[0:5, :], ycols='T_1', title='Pie plot', kind='pie')
        plotter.plot(df, title='Scatter plot', kind='scatter',
                     xcols='T_1', ycols='T_2', s='T_1', c='m_flow',
                     marker='x', linestyle='solid')
        plotter.plot(df, xcols='T_1', ycols='T_2', title='Hexbin plot', kind='hexbin')

    # examples with multi_scatter
    if _selected('multi_scatter'):
        plotter.plot(df, xcols='T_1', ycols=['T_2', 'T_3'], title='Multi scatter plot',
                     kind='multi_scatter', legend_loc='outside upper left',
                     marker='cycle')
        plotter.plot(df, xcols=['T_1', 'T_2'], ycols=['T_2', 'T_3'],
                     title='Multi scatter plot', kind='multi_scatter',
                     leg_labels=['T_2 over T_1', 'T_3 over T_2'],
                     marker=['x', 'o'],
                     legend_loc='lower right')

    # examples with R2 plots
    # NOTE(review): all R2 examples use the same column ('T_2') for x and y - confirm this
    # is intended
    if _selected('R2_plots'):
        # simple R2 plots
        plotter.plot(df, xcols='T_2', ycols='T_2', title='R2 scatter plot', kind='R2_scatter')
        plotter.plot(df, xcols='T_2', ycols='T_2', title='R2 hexbin plot', kind='R2_hexbin')
        # examples with more customization
        plotter_presentation.plot(df, xcols='T_2', ycols='T_2', kind='R2_scatter',
                                  title='R2 scatter plot, customized, presentation',
                                  metrics=False,
                                  marker='x', color='red', alpha=0.5, s='T_1',
                                  xlabel='measured temperature', ylabel='modeled temperature',
                                  xlim=[0, 10], ylim=[0, 10])
        plotter.plot(df, xcols='T_2', ycols='T_2', kind='R2_hexbin',
                     title='R2 hexbin plot: customized, paper', metrics=['R2', 'RMSE'],
                     cmap='Reds', gridsize=20,
                     xlabel='measured temperature', ylabel='modeled temperature',
                     xlim=[0, 10], ylim=[0, 10], style='paper')

    # examples with secondary axis
    if _selected('sec_axis'):
        plotter_presentation.plot(
            df, title='Timeseries with secondary y-axis (and outside legend)',
            ycols=['m_flow'], sec_ycols=['T_1', 'T_2'], sec_marker='x', sec_linestyle='',
            legend_loc='outside right upper'
        )
        # same with horizontal y labels
        plotter_presentation.plot(
            df, title='', lab_hor_ylabels=True,
            ycols=['m_flow'], sec_ycols=['T_1', 'T_2'], sec_marker='x', sec_linestyle='',
            legend_loc='outside right upper'
        )
        plotter.plot(df, kind='multi_scatter', title='Multi-Scatter plot',
                     xcols='T_1', ycols='T_2',
                     xlim=[10, 20],
                     sec_xcols='T_1', sec_ycols='m_flow',
                     legend_loc='outside lower left', legend_ncols=7)
        # Note that xcols and sec_xcols are different physical quantities
        # => no xlabel will be printed, unless defined explicitly
        plotter.plot(df, xcols=['T_1', 'T_2'], ycols=['T_2', 'T_3'],
                     sec_xcols='m_flow', sec_ycols='T_2',
                     title='Multi scatter plot', kind='multi_scatter',
                     leg_labels=['T_2 over T_1', 'T_3 over T_2', 'T_2 over m_flow'],
                     legend_loc='lower right',
                     #xlabel = "T or m_flow"
                     )

    # examples with subplots
    if _selected('subplots'):
        plots = {
            'Plot a timeseries': {'ycols': ['T_1', 'T_2'], 'kind': 'line'},
            'Histogram': {'ycols': 'm_flow', 'kind': 'hist'}
        }
        plotter.plot_subplots(df, plots)
        plotter.plot_subplots(df, plots, layout=(1, 2), figsize=(8, 4), wspace=0)
        # behavior with more than 3 plots => layout is inferred
        plots_2 = {
            'Plot a timeseries': {'ycols': ['T_3', 'T_2'], 'kind': 'line'},
            'Histogram': {'ycols': 'T_1', 'kind': 'hist'},
            'Histogram 2': {'ycols': 'T_2', 'kind': 'hist', 'title': None},
            'Histogram 3': {'ycols': 'm_flow', 'color': 'purple', 'bins': 20, 'grid': False,
                            'kind': 'hist'}
        }
        plotter.plot_subplots(df, plots_2)

    # Example on cycler usage
    if _selected('cycler'):
        # Create a sample DataFrame with 10 vertically shifted sine curves
        x = np.linspace(0, 3*np.pi, 20)
        y = {}
        for i in range(10):
            y[f'line_{i}'] = np.sin(x) + i/2
        df2 = pd.DataFrame(y, index=x)
        # Create an instance of Plotter
        plotter_cycler = Plotter(style_default='presentation')
        # Plot data using the plot method
        # No arguments
        plotter_cycler.plot(df2, title='Cycler example without arguments')
        # Different lengths arguments for linestyle, marker and color
        plotter_cycler.plot(df2, title='Cycler example with different lengths props',
                            linestyle=['dashed', 'dotted'],
                            marker=['o', 'x', 's'],
                            color=['red', 'blue', 'green', 'orange'])
        # With secondary axis
        # NOTE(review): columns[:5] and columns[6:] leave out 'line_5' in the calls below -
        # confirm that this gap is intended
        plotter_cycler.plot(df2, ycols=list(df2.columns[:5]), sec_ycols=list(df2.columns[6:]),
                            title='Cycler example with secondary axis',
                            linestyle=['dashed', 'dotted'], ylim=[-1, 6],
                            marker=['o', 'x', 's'], sec_linestyle='',
                            color=['red', 'blue', 'green', 'orange'])
        # Using the 'cycle' keyword
        plotter_cycler.plot(df2, ycols=list(df2.columns[:5]), sec_ycols=list(df2.columns[6:]),
                            title='Cycler example with secondary axis and linestyle "cycle"',
                            linestyle='cycle', ylim=[-1, 6],
                            color='cycle', sec_color='red')
        plotter_cycler.plot(df2, ycols=list(df2.columns[:5]), sec_ycols=list(df2.columns[6:]),
                            title='Cycler example with secondary axis and marker "cycle"',
                            linestyle='', marker='cycle', ylim=[-1, 6],
                            color='cycle', sec_color='red', markersize=[10, 15])
        # Using predefined cycles
        plotter_cycler.plot(
            df2,
            title='Cycler example with "filled" marker, "main" linestyle and "rgb" color',
            marker='filled', linestyle='main', color='rgb')