"""
Helper functions to produce nicer plots
"""
from scipy import stats
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
import seaborn as sns
from collections.abc import Hashable
def _resolve_column(data, value):
    """
    Replaces a column name by the corresponding column of `data`

    :param data: pandas dataframe
    :param value: column name, or anything else (array, series, None), which is returned unchanged
    :return: `data[value]` if `value` is a column of `data`, otherwise `value` itself
    :raises ValueError: if `value` is a string or tuple that is not a column of `data`
    """
    if isinstance(value, Hashable) and value in data.columns:
        return data[value]
    if isinstance(value, (str, tuple)):
        raise ValueError(f"{value} not in pandas dataframe")
    return value


def joint_density_scatter(x, y, data=None, max_dens_npoints=300, diagonal=False, xlim=None, ylim=None,
                          hue=None, cmap='gist_heat', colors=None, **kwargs):
    """
    Produces a scatter plot color-coded by density jointly with histograms

    :param x: (N, ) array or series with x-values (or column name if data is not None)
    :param y: (N, ) array or series with y-values (or column name if data is not None)
    :param data: pandas dataframe with the data
    :param max_dens_npoints: maximum number of points to use in the density calculation.
        Computation time on my laptop is roughly (max_dens_points / 600) seconds.
    :param diagonal: If True plots a diagonal straight line illustrating where both axes are equal
    :param xlim: lower and upper bound in the x-direction; points outside are left out of the
        scatter plot and of the x-histogram
    :param ylim: lower and upper bound in the y-direction; points outside are left out of the
        scatter plot and of the y-histogram
    :param hue: (N, ) categorical array/series to set the color coding (or column name if data is not None)
    :param cmap: colormap used for the density color coding when `hue` is not set or has a single value
    :param colors: seaborn color palette or sequence of colors (default: matplotlib color cycle)
    :param kwargs: keyword arguments passed on to plt.scatter.
        Particularly useful are adjusting the size (e.g. s=1) or the opacity (e.g. alpha=0.3).
    :return: seaborn JointGrid object with scatter plot and histograms
    :raises ValueError: if `x`, `y` or `hue` is a string/tuple that is not a column of `data`,
        or if `hue` has more unique values than there are colors
    """
    colors = sns.color_palette(colors)
    # x/y passed by keyword: recent seaborn versions no longer accept them positionally
    jg = sns.JointGrid(x=x, y=y, data=data, xlim=xlim, ylim=ylim)
    if data is not None:
        x = _resolve_column(data, x)
        y = _resolve_column(data, y)
        hue = _resolve_column(data, hue)
    if hue is not None:
        # plain array, so that `hue == val` and boolean indexing also work when hue is a list
        hue = np.asarray(hue)

    x_range = np.ones(x.shape, dtype='bool') if xlim is None else (x >= xlim[0]) & (x <= xlim[1])
    y_range = np.ones(y.shape, dtype='bool') if ylim is None else (y >= ylim[0]) & (y <= ylim[1])
    in_range = x_range & y_range

    hue_values = None if hue is None else np.unique(hue)
    single_color = hue_values is None or hue_values.size == 1
    if not single_color and hue_values.size > len(colors):
        raise ValueError("More unique values in hue than values in the palette")

    scatter_hue = None
    scatter_colors = colors
    if hue is not None:
        # hue has to be cut by xlim/ylim exactly like x and y are
        scatter_hue = hue[np.asarray(in_range)]
        # only pass the colors of categories that survive the cut, so every category gets
        # the same color in the scatter plot as in the marginal histograms
        present = np.isin(hue_values, scatter_hue)
        scatter_colors = [color for color, keep in zip(colors, present) if keep]

    density_scatter(x[in_range], y[in_range], axes=jg.ax_joint, max_dens_npoints=max_dens_npoints,
                    diagonal=diagonal, hue=scatter_hue, cmap=cmap, colors=scatter_colors, **kwargs)

    if single_color:
        sns.distplot(x[x_range], ax=jg.ax_marg_x, color='r', axlabel=False)
        sns.distplot(y[y_range], ax=jg.ax_marg_y, color='r', vertical=True, axlabel=False)
    else:
        finite = np.isfinite(x) & np.isfinite(y)
        for color, val in zip(colors, hue_values):
            use = (hue == val) & finite
            sns.distplot(x[use & x_range], ax=jg.ax_marg_x, axlabel=False, color=color)
            sns.distplot(y[use & y_range], ax=jg.ax_marg_y, vertical=True, axlabel=False, color=color)
    return jg
def _get_kde(x, y, max_dens_npoints=300):
"""
Returns a KDE
:param x: x-values
:param y: y-values
:param max_dens_npoints: maximum number of points to use in the density calculation
Computation time on my laptop is roughly (max_dens_points / 600) seconds
:return: density estimator
"""
mask = (
np.isfinite(x) & (x != 0) &
np.isfinite(y) & (y != 0)
)
subsample_density = max(int(round(mask.sum() / max_dens_npoints)), 1)
x_for_dens = x[mask][::subsample_density]
y_for_dens = y[mask][::subsample_density]
try:
return stats.gaussian_kde([x_for_dens, y_for_dens])
except np.linalg.LinAlgError:
return lambda coords: np.ones(coords[0].shape)
def density_scatter(x, y, xlog=False, ylog=False, max_dens_npoints=300, axes=None,
                    diagonal=False, hue=None, colors=None, cmap='gist_heat', **kwargs):
    """
    Produces a scatter plot color-coded by density

    :param x: (N, ) array with the x-values
    :param y: (N, ) array with the y-values
    :param xlog: if True, makes the x-axis logarithmic (and estimates the density in log(x))
    :param ylog: if True, makes the y-axis logarithmic (and estimates the density in log(y))
    :param max_dens_npoints: maximum number of points to use in each density calculation.
        Computation time on my laptop is roughly (max_dens_points / 600) seconds
    :param axes: matplotlib Axes object, which will be used for the scatter plot (default: current axis)
    :param diagonal: If True plots a diagonal straight line illustrating where both axes are equal
    :param hue: (N, ) index array with a categorical variable indicating the hue
    :param colors: seaborn color palette or sequence of colors (default: matplotlib color cycle)
    :param cmap: color-map to use if hue is not set (or has a single value)
    :param kwargs: keyword arguments passed on to plt.scatter.
        Particularly useful are adjusting the size (e.g. s=1) or the opacity (e.g. alpha=0.3)
    :return: scatter object (the return value of `axes.scatter`)
    :raises ValueError: if x or y is not 1-dimensional, if x, y and hue differ in size,
        or if hue has more unique values than there are colors
    """
    import colorsys

    colors = sns.color_palette(colors)
    if x.ndim != 1 or y.ndim != 1:
        raise ValueError("Can only work with 1-dimensional input")
    if x.size != y.size:
        raise ValueError("Arrays for x and y should have the same number of elements")
    if hue is not None:
        hue = np.asarray(hue)
        if hue.size != x.size:
            raise ValueError("Array for hue should have the same number of elements as x and y")
        if np.unique(hue).size > len(colors):
            raise ValueError("Categorical variable has more values than options available in the color")
    if axes is None:
        axes = plt.gca()

    xval = np.log(x) if xlog else x
    yval = np.log(y) if ylog else y
    density = _get_kde(xval, yval, max_dens_npoints)([xval, yval])
    # nanmax: points with non-finite coordinates must not determine the normalisation
    rel_density = density / np.nanmax(density)

    categories = None if hue is None else np.unique(hue)
    if categories is None or categories.size == 1:
        color = plt.get_cmap(cmap)(rel_density)
    else:
        # fraction of the local density contributed by each category: (ncategories, N)
        full_dens = np.stack([
            _get_kde(xval[hue == val], yval[hue == val], max_dens_npoints)([xval, yval])
            for val in categories
        ], 0)
        full_dens /= full_dens.sum(0)[None, :]
        # category colors mixed by those fractions: (3, N)
        initial_rgb = (full_dens[:, None, :] * np.array(colors)[:full_dens.shape[0], :, None]).sum(0)
        hue_channel, _, saturation = np.vectorize(colorsys.rgb_to_hls)(*initial_rgb)
        # keep the mixed hue/saturation; the lightness (0.1 to 0.95) encodes the overall density
        rgb = np.vectorize(colorsys.hls_to_rgb)(hue_channel, rel_density * 0.85 + 0.1, saturation)
        color = np.stack(rgb, -1)

    res = axes.scatter(x, y, c=color, **kwargs)
    if xlog:
        axes.set_xscale('log')
    if ylog:
        axes.set_yscale('log')
    if diagonal:
        xmin, xmax, ymin, ymax = axes.axis()
        dmin = max(xmin, ymin)
        dmax = min(xmax, ymax)
        # identical samples for x and y, so the line follows x == y on any mix of axis scales
        if xlog or ylog:
            diag = np.logspace(np.log10(dmin), np.log10(dmax), 301)
        else:
            diag = np.linspace(dmin, dmax, 301)
        axes.plot(diag, diag, 'k-', scalex=False, scaley=False)
    return res
def _grid_shape(nsubplots):
"""
Gets a decent default shape for `nsubplots` sub-plots
:param nsubplots: number of plots to fit in a grid
:return: (nrows, ncols) tuple
"""
if nsubplots <= 3:
return (1, nsubplots)
for ncols in range(int(np.floor(np.sqrt(nsubplots))), 1, -1):
if nsubplots % ncols == 0:
nrows = nsubplots // ncols
if nrows <= 3:
return ncols, nrows
return nrows, ncols
return _grid_shape(nsubplots + 1)
def default_grid(nsubplots, subplot_spec=None, **kwargs) -> "GridSpec | GridSpecFromSubplotSpec":
    """
    Creates a default layout for identically sized subplots

    :param nsubplots: number of subplots; the grid may contain a few more cells than this
        when `nsubplots` does not factorise into a convenient shape
    :param subplot_spec: which subplot the new subplots should be contained in
        (default: a new top-level grid)
    :param kwargs: additional parameters defining the spacing of the subplots, passed on to
        GridSpec (or GridSpecFromSubplotSpec if `subplot_spec` is given)
    :return: new GridSpec, or a GridSpecFromSubplotSpec nested in `subplot_spec` if that is given
    """
    nrows, ncols = _grid_shape(nsubplots)
    if subplot_spec is None:
        return GridSpec(nrows, ncols, **kwargs)
    return GridSpecFromSubplotSpec(nrows, ncols, subplot_spec, **kwargs)