# -*- coding: utf-8 -*-
"""
Created on Tue Aug 9 14:43:08 2022
Helper functions for plotting from Felix Micus (used in fm_plots)
@author: FM
"""
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math
import matplotlib.dates as mdates
import matplotlib.ticker as mticker
# Module-wide debug switch: when True, label_line prints the data points,
# screen coordinates, angle and offset it used for each label and marks the
# two data points it used with small "x" markers on the axes.
debug_mode = False
[docs]
def label_line(line, ax, label, color='black', size=10, position=0.5, offset=None,
               verticalalignment="bottom"):
    """
    Attach a text label to a plotted line, rotated to follow the line's local slope.

    Parameters
    ----------
    line : matplotlib.lines.Line2D object
        The line to which the label will be attached.
    ax : matplotlib.axes.Axes object
        The axes on which the line is plotted.
    label : str
        The text to be displayed as the label.
    color : str, optional
        The color of the text (default is 'black').
    size : float, optional
        The font size of the text (default is 10).
    position : float, optional
        Relative position along the line where the label will be placed,
        given as a fraction of the line's data points (index based, not arc
        length): 0 is the first point, 1 the last. Values outside [0, 1]
        are clamped to the first/last usable point. Positions <= 0.1 are
        left-aligned, >= 0.9 right-aligned, everything else centered.
        The default is 0.5 (the middle data point).
    offset : numerical, optional
        Offset of the label from the line, in multiples of the font size
        ('offset fontsize' units), applied perpendicular to the line at the
        labeled point. With verticalalignment "top" the label is pushed
        below the line; with any other alignment it is pushed above it.
        A negative value reverses the direction. The default is None (no offset).
    verticalalignment : str, optional
        Vertical alignment of the label. Choice of "bottom", "baseline",
        "center", "top". The default is "bottom".

    Returns
    -------
    text : matplotlib.text.Annotation
        The annotation object representing the label.

    Raises
    ------
    ValueError
        If the line has no data points.

    Notes
    -----
    The rotation angle is computed once, from the axes limits and figure
    layout at call time. Set the axis limits before labeling, otherwise the
    label angle may not match the drawn line.

    Example
    -------
    import matplotlib.pyplot as plt
    import numpy as np
    x = np.linspace(0, 10, 100)
    y = np.sin(x)
    fig, ax = plt.subplots()
    ax.plot(x, y)
    label_line(line=ax.lines[0], ax=ax, label='Peak', color='red', size=12, position=0.75)
    plt.show()
    """
    # Use the processed (unit-converted) data: plain float arrays even when the
    # line was plotted from datetimes or pandas objects, so that positional
    # indexing and ax.transData work for every input type.
    xdata, ydata = line.get_data(orig=False)
    n_points = len(xdata)
    if n_points == 0:
        raise ValueError("label_line: the line has no data points to label.")
    last = n_points - 1

    # Index of the labeled point, clamped so that a following point exists
    # whenever the line has at least two points (needed for the slope).
    pos = min(max(int(last * position), 0), max(last - 1, 0))
    nxt = min(pos + 1, last)
    x1, y1 = xdata[pos], ydata[pos]
    x2, y2 = xdata[nxt], ydata[nxt]

    # Determine the horizontal alignment of the text from its position along the line
    if position <= 0.1:
        horizontalalignment = 'left'
    elif position >= 0.9:
        horizontalalignment = 'right'
    else:
        horizontalalignment = 'center'

    # The angle must be measured in display coordinates, otherwise different
    # x/y scales of the axes would distort it.
    sp1 = ax.transData.transform((x1, y1))
    sp2 = ax.transData.transform((x2, y2))
    rise = sp2[1] - sp1[1]
    run = sp2[0] - sp1[0]
    slope_angle = np.arctan2(rise, run)
    slope_degrees = np.degrees(slope_angle)

    # (-sin, cos) is the unit normal to the left of the line direction ("up"
    # for a left-to-right line). 'offset fontsize' uses the same unit in x
    # and y, so the offset stays perpendicular on screen.
    xytext = (0, 0)
    if offset is not None:
        direction = -1 if verticalalignment == "top" else 1
        xytext = (-direction * np.sin(slope_angle) * offset,
                  direction * np.cos(slope_angle) * offset)

    text = ax.annotate(label, xy=(x1, y1), xytext=xytext,
                       rotation=slope_degrees, size=size, color=color,
                       textcoords='offset fontsize', horizontalalignment=horizontalalignment,
                       verticalalignment=verticalalignment, transform_rotates_text=True,
                       rotation_mode="anchor")

    if debug_mode:
        df = pd.DataFrame({"x": [x1, x2, sp1[0], sp2[0], run],
                           "y": [y1, y2, sp1[1], sp2[1], rise]},
                          index=[str(pos), str(nxt), "sp1", "sp2", "delta"])
        print("Point: " + str(label))
        print(df)
        print("degrees = " + str(slope_degrees))
        # Mark the two data points used to compute the slope.
        ax.scatter(x=[x1, x2], y=[y1, y2], color=color, marker="x", s=5)
        if offset is not None:
            print("offset = " + str(xytext))
        print("--------------")
    return text
[docs]
def _date_formatter(fmt):
    """
    Return a tick formatter for the strftime pattern *fmt*.

    ``%#d`` (day of month without leading zero) is a Windows-only strftime
    extension and is not portable to other C libraries. Patterns containing
    it are rendered by substituting the day number before calling strftime,
    which gives the same result on every platform. All other patterns use a
    plain ``mdates.DateFormatter``.
    """
    if "%#d" not in fmt:
        return mdates.DateFormatter(fmt)

    def _format(value, pos=None):
        # Same timezone handling as DateFormatter(fmt): tz=None -> rcParams["timezone"].
        stamp = mdates.num2date(value)
        return stamp.strftime(fmt.replace("%#d", str(stamp.day)))

    return mticker.FuncFormatter(_format)


def scale_datetime_axis(ax, start_time, end_time, style="tight", label_style="shortname",
                        minor_labels=True, year=False):
    """
    Manually scale and label the datetime x-axis of a plot.

    Tick locators and formatters are chosen from the length of the period
    between start_time and end_time. Explicit tick layouts exist for periods
    shorter than 500 days. For longer periods the existing locators and
    formatters are left untouched and only the axis limits are set.

    Parameters
    ----------
    ax : Axes object
        Axes object on which to scale the x-axis.
    start_time : datetime-like
        Starting value (left limit of x-axis). datetime.datetime,
        pandas.Timestamp and numpy.datetime64 are supported.
    end_time : datetime-like
        End value (right limit of x-axis), same type as start_time.
    style : str, optional
        Style for the placement of the labels. Implemented are "wide" and
        "tight". Any other value leaves the ticks untouched. The default is "tight".
    label_style : str, optional
        Style used to represent months (and weekdays in some tick layouts). Implemented are:
        "shortname": Sep, Oct, Tue, Mon
        "name": April, August, Monday
        "number": 02., 06.
        The default is "shortname".
    minor_labels : bool, optional
        Whether minor ticks should get labels. The default is True.
    year : bool, optional
        Whether the (two-digit) year should be added to dates. The default is False.

    Returns
    -------
    bool
        Always True.

    Raises
    ------
    KeyError
        If label_style is not one of "shortname", "name", "number".
    """
    # The dict lookup also rejects unknown label styles with a KeyError.
    month = {"shortname": " %b", "name": " %B", "number": "%m."}[label_style]
    y = "%y" if year else ""

    delta = end_time - start_time
    if isinstance(delta, np.timedelta64):
        # numpy timedeltas have no total_seconds(); pandas.Timedelta does.
        delta = pd.Timedelta(delta)
    days = delta.total_seconds() / (24 * 60 * 60)

    if style in ("wide", "tight") and days < 1.5:
        # Both styles use hourly ticks for short periods. hours < 36 here, so
        # hours // 20 is at most 1 and the interval is effectively one hour.
        hours = days * 24
        interval = int(max(hours // 20, 1))
        ax.xaxis.set_major_locator(mdates.HourLocator(interval=interval))
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
        if hours < 15:
            ax.xaxis.set_minor_locator(mticker.AutoMinorLocator(n=2))
            if minor_labels:
                ax.xaxis.set_minor_formatter(mdates.DateFormatter("%H:%M"))
    elif style == "wide":
        if days < 12:
            # Daily major ticks, each day subdivided into int(24 // days) parts.
            n = int(24 // days)
            ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
            ax.xaxis.set_major_formatter(mdates.DateFormatter("%d.%m." + y))
            ax.xaxis.set_minor_locator(mticker.AutoMinorLocator(n=n))
            if minor_labels:
                ax.xaxis.set_minor_formatter(mdates.DateFormatter("%H:%M"))
        elif days < 50:
            # Weekly major ticks on Mondays, daily minor ticks.
            ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=0))
            ax.xaxis.set_major_formatter(mdates.DateFormatter("%a %d.%m." + y))
            ax.xaxis.set_minor_locator(mdates.DayLocator(interval=1))
            if minor_labels:
                ax.xaxis.set_minor_formatter(mdates.DateFormatter("%d.%m."))
        elif days < 100:
            ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=[1, 11, 21]))
            ax.xaxis.set_major_formatter(mdates.DateFormatter("%d.%m." + y))
            # NOTE(review): here minor_labels toggles the daily minor ticks
            # (no minor formatter is set); kept as originally implemented.
            if minor_labels:
                ax.xaxis.set_minor_locator(mdates.DayLocator(interval=1))
        elif days < 500:
            ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=1))
            ax.xaxis.set_major_formatter(mdates.DateFormatter("%d.%m." + y))
            if days < 200:
                ax.xaxis.set_minor_locator(mdates.DayLocator(bymonthday=[11, 21]))
                if minor_labels:
                    ax.xaxis.set_minor_formatter(mdates.DateFormatter("%d.%m."))
            elif days < 400:
                ax.xaxis.set_minor_locator(mdates.DayLocator(bymonthday=15))
                if minor_labels:
                    ax.xaxis.set_minor_formatter(mdates.DateFormatter("%d.%m."))
    elif style == "tight":
        # Numeric date ("9.8.22") and date with month name ("9. Aug 22").
        # strip() removes the trailing separator when no year is shown.
        num_date = "%#d.%m." + y
        name_date = ("%#d." + month + " " + y).strip()
        if days < 12:
            # Daily major ticks, each day subdivided into int(30 // days) parts.
            n = int(30 // days)
            ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
            major_fmt = num_date if label_style == "number" else name_date
            ax.xaxis.set_major_formatter(_date_formatter(major_fmt))
            ax.xaxis.set_minor_locator(mticker.AutoMinorLocator(n=n))
            if minor_labels:
                ax.xaxis.set_minor_formatter(mdates.DateFormatter("%H:%M"))
        elif days < 50:
            # Weekly major ticks on Mondays, daily minor ticks.
            ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=0))
            ax.xaxis.set_major_formatter(_date_formatter("%a %#d.%m." + y))
            ax.xaxis.set_minor_locator(mdates.DayLocator(interval=1))
            if minor_labels:
                if label_style == "shortname":
                    ax.xaxis.set_minor_formatter(mdates.DateFormatter("%a"))
                else:
                    ax.xaxis.set_minor_formatter(_date_formatter("%#d.%m."))
        elif days < 100:
            ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=[1, 11, 21]))
            major_fmt = name_date if label_style == "shortname" else num_date
            ax.xaxis.set_major_formatter(_date_formatter(major_fmt))
            # NOTE(review): here minor_labels toggles the daily minor ticks
            # (no minor formatter is set); kept as originally implemented.
            if minor_labels:
                ax.xaxis.set_minor_locator(mdates.DayLocator(interval=1))
        elif days < 500:
            ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=1))
            major_fmt = name_date if label_style == "shortname" else num_date
            ax.xaxis.set_major_formatter(_date_formatter(major_fmt))
            if days < 200:
                ax.xaxis.set_minor_locator(mdates.DayLocator(bymonthday=[11, 21]))
                if minor_labels:
                    ax.xaxis.set_minor_formatter(_date_formatter("%#d.%m."))
            elif days < 400:
                ax.xaxis.set_minor_locator(mdates.DayLocator(bymonthday=15))
                if minor_labels:
                    ax.xaxis.set_minor_formatter(_date_formatter("%#d.%m."))
    ax.set_xlim(start_time, end_time)
    return True
[docs]
def is_iterable_of_strings(obj):
    """
    Check whether the passed object is a white-listed container of strings.

    Only list, tuple, set and frozenset instances whose items are all str
    are accepted. A result of False therefore does not prove that obj is
    not an iterable of strings (e.g. generators or dict keys are rejected);
    it only means obj is not on the white list. Empty containers of a
    white-listed type count as iterables of strings.

    Parameters
    ----------
    obj : object
        The object to check.

    Returns
    -------
    bool
        True if obj is a list/tuple/set/frozenset containing only strings,
        False otherwise.
    """
    allowed_containers = (list, tuple, set, frozenset)
    if not isinstance(obj, allowed_containers):
        return False
    for element in obj:
        if not isinstance(element, str):
            return False
    return True
if __name__ == '__main__':
    # =========================================================================
    # Visual smoke test of label_line on a sine curve
    # =========================================================================
    print("Testing label line function")
    # Prints diagnostics and marks the data points each label is attached to.
    debug_mode = True
    print("Creating test data (sin curve)")
    x_values = np.linspace(0, 3 * np.pi, 1000)
    fig, ax = plt.subplots()
    ax.plot(x_values, np.sin(x_values))
    # Fix the limits first: label_line computes the label angle from them.
    ax.set_xlim(0, 3 * np.pi)
    ax.set_ylim(-1.2, 1.2)
    print("Labeling various points")
    test_line = ax.lines[0]
    label_specs = [
        dict(label='0.01', color='green', position=0.01),
        dict(label='0.5', color='red', position=0.51, verticalalignment="top", offset=1),
        dict(label='0.7', color='red', position=0.7, offset=-3, verticalalignment="bottom"),
        dict(label='0.99', color='red', position=0.99, verticalalignment="top", offset=0.1),
    ]
    for spec in label_specs:
        label_line(line=test_line, ax=ax, size=12, **spec)
    plt.show()