Source code for holodeck.plot

"""Plotting module.

Provides convenience methods for generating standard plots and components using `matplotlib`.

"""

import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import kalepy as kale

import holodeck as holo
from holodeck import utils, log
from holodeck.constants import MSOL, YR

# Base figure width; `figax_single` uses `FIGSIZE` and `figax_double` uses `2*FIGSIZE` as the
# first element of `figsize`.
FIGSIZE = 6
# Base font size applied to `font.size` by `figax_single` / `figax_double`
# (legend and tick labels are set to 0.8 times this value there).
FONTSIZE = 13
# (sqrt(5) - 1) / 2 ~ 0.618; used as the default height-to-width ratio of preset figures.
GOLDEN_RATIO = (np.sqrt(5) - 1) / 2

# NOTE: the following lines modify matplotlib's global state as a side effect of importing
# this module.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.15
plt.rcParams["mathtext.fontset"] = "cm"
plt.rcParams["font.family"] = "serif"
plt.rcParams["legend.handlelength"] = 1.5
plt.rcParams["lines.solid_capstyle"] = 'round'
# Font sizes are deliberately not set at import time; `figax_single` / `figax_double` set them.
# plt.rcParams["font.size"] = FONTSIZE
# plt.rcParams["legend.fontsize"] = FONTSIZE*0.8
# mpl.rcParams['xtick.labelsize'] = FONTSIZE*0.8
# mpl.rcParams['ytick.labelsize'] = FONTSIZE*0.8

# Standard axis labels (matplotlib mathtext), shared between plotting functions.
LABEL_GW_FREQUENCY_YR = r"GW Frequency $[\mathrm{yr}^{-1}]$"
LABEL_GW_FREQUENCY_HZ = r"GW Frequency $[\mathrm{Hz}]$"
LABEL_GW_FREQUENCY_NHZ = r"GW Frequency $[\mathrm{nHz}]$"
LABEL_SEPARATION_PC = r"Binary Separation $[\mathrm{pc}]$"
LABEL_CHARACTERISTIC_STRAIN = r"GW Characteristic Strain"
LABEL_HARDENING_TIME = r"Hardening Time $[\mathrm{Gyr}]$"
LABEL_CLC0 = r"$C_\ell / C_0$"

# Mapping from parameter names (keys) to the mathtext labels used when plotting them.
PARAM_KEYS = {
    'hard_time': r"phenom $\tau_f$",
    'hard_gamma_inner': r"phenom $\nu_\mathrm{inner}$",
    'hard_gamma_outer': r"phenom $\nu_\mathrm{outer}$",
    'hard_gamma_rot' : r"phenom $\nu_{\mathrm{rot}}$",
    'gsmf_phi0': r"GSMF $\psi_0$",
    'gsmf_mchar0_log10': r"GSMF $m_{\psi,0}$",
    'gsmf_alpha0': r"GSMF $\alpha_{\psi,0}$",
    'gpf_zbeta': r"GPF $\beta_{p,z}$",
    'gpf_qgamma': r"GPF $\gamma_{p,0}$",
    'gmt_norm': r"GMT $T_0$",
    'gmt_zbeta': r"GMT $\beta_{t,z}$",
    'mmb_mamp_log10': r"MMB $\mu$",
    'mmb_plaw': r"MMB $\alpha_{\mu}$",
    'mmb_scatter_dex': r"MMB $\epsilon_{\mu}$",
}

# Mathtext labels; judging by the symbols these are the single-source expectation value
# <N_SS>, the background quantity DP_BG, and their ratio -- TODO confirm against callers.
LABEL_DPRATIO = r"$\langle N_\mathrm{SS} \rangle / \mathrm{DP}_\mathrm{BG}$"
LABEL_EVSS = r"$\langle N_\mathrm{SS} \rangle$"
LABEL_DPBG = r"$\mathrm{DP}_\mathrm{BG}$"

# Colors of the axes property cycle, captured once at import (after the style reset above).
COLORS_MPL = plt.rcParams['axes.prop_cycle'].by_key()['color']


class MidpointNormalize(mpl.colors.Normalize):
    """Linear normalization in which a prescribed midpoint value is mapped to 0.5.

    Values are mapped piecewise-linearly such that `vmin` ==> 0.0, `midpoint` ==> 0.5, and
    `vmax` ==> 1.0, so that diverging colormaps work their way out to either side of the
    midpoint.  NaN inputs are masked in the returned array.

    e.g. ``im = ax1.imshow(array, norm=MidpointNormalize(midpoint=0.0, vmin=-100, vmax=100))``

    """

    def __init__(self, vmin=None, vmax=None, midpoint=0.0, clip=False):
        super().__init__(vmin, vmax, clip)
        self.midpoint = midpoint

    def __call__(self, value, clip=None):
        """Map data values onto [0, 1].  The `clip` argument is accepted but not used."""
        data_points = [self.vmin, self.midpoint, self.vmax]
        unit_points = [0, 0.5, 1]
        mapped = np.interp(value, data_points, unit_points)
        return np.ma.masked_array(mapped, np.isnan(value))

    def inverse(self, value):
        """Map normalized values in [0, 1] back onto data values."""
        unit_points = [0, 0.5, 1]
        data_points = [self.vmin, self.midpoint, self.vmax]
        restored = np.interp(value, unit_points, data_points)
        return np.ma.masked_array(restored, np.isnan(value))
class MidpointLogNormalize(mpl.colors.LogNorm):
    """Logarithmic normalization in which a prescribed midpoint value is mapped to 0.5.

    The anchor points are `vmin` ==> 0.0, `midpoint` ==> 0.5, and `vmax` ==> 1.0; interpolation
    between them is delegated to `utils.interp` with the data axis flagged as logarithmic.
    Unlike `MidpointNormalize`, the results are returned as-is (NaN values are not masked).

    """

    def __init__(self, vmin=None, vmax=None, midpoint=0.0, clip=False):
        super().__init__(vmin, vmax, clip)
        self.midpoint = midpoint

    def __call__(self, value, clip=None):
        """Map data values onto [0, 1].  The `clip` argument is accepted but not used."""
        data_points = [self.vmin, self.midpoint, self.vmax]
        unit_points = [0, 0.5, 1]
        return utils.interp(value, data_points, unit_points, xlog=True, ylog=False)

    def inverse(self, value):
        """Map normalized values in [0, 1] back onto data values."""
        unit_points = [0, 0.5, 1]
        data_points = [self.vmin, self.midpoint, self.vmax]
        return utils.interp(value, unit_points, data_points, xlog=False, ylog=True)
def _reset_plot_style():
    """Reset matplotlib to the 'default' style and apply this module's rcParams, incl. font sizes."""
    mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
    plt.rcParams.update({
        'axes.grid': True,
        'grid.alpha': 0.15,
        "mathtext.fontset": "cm",
        "font.family": "serif",
        "legend.handlelength": 1.5,
        "lines.solid_capstyle": 'round',
        "font.size": FONTSIZE,
        "legend.fontsize": FONTSIZE*0.8,
    })
    mpl.rcParams.update({
        'xtick.labelsize': FONTSIZE*0.8,
        'ytick.labelsize': FONTSIZE*0.8,
    })


def figax_single(height=None, **kwargs):
    """Create a figure/axes via `figax` using the single-width preset.

    Resets the global matplotlib style, then calls `figax` with a default figure size of
    ``[FIGSIZE, height]`` and default subplot margins.  Any of these can be overridden by
    passing the corresponding keyword argument.

    Parameters
    ----------
    height : scalar or None
        Figure height; if `None`, ``FIGSIZE * GOLDEN_RATIO`` is used.
    **kwargs : dict
        Additional keyword arguments passed on to `figax`.

    Returns
    -------
    The return value of `figax`, i.e. ``(fig, axes)``.

    """
    _reset_plot_style()
    if height is None:
        height = FIGSIZE * GOLDEN_RATIO

    presets = dict(figsize=[FIGSIZE, height], left=0.15, bottom=0.15, right=0.95, top=0.95)
    for key, val in presets.items():
        kwargs.setdefault(key, val)

    return figax(**kwargs)


def figax_double(height=None, **kwargs):
    """Create a figure/axes via `figax` using the double-width preset.

    Resets the global matplotlib style, then calls `figax` with a default figure size of
    ``[2*FIGSIZE, height]`` and default subplot margins.  Any of these can be overridden by
    passing the corresponding keyword argument.

    Parameters
    ----------
    height : scalar or None
        Figure height; if `None`, ``2 * FIGSIZE * GOLDEN_RATIO`` is used.
    **kwargs : dict
        Additional keyword arguments passed on to `figax`.

    Returns
    -------
    The return value of `figax`, i.e. ``(fig, axes)``.

    """
    _reset_plot_style()
    if height is None:
        height = 2 * FIGSIZE * GOLDEN_RATIO

    presets = dict(figsize=[2*FIGSIZE, height], left=0.10, bottom=0.10, right=0.98, top=0.95)
    for key, val in presets.items():
        kwargs.setdefault(key, val)

    return figax(**kwargs)
def figax(figsize=[7, 5], ncols=1, nrows=1, sharex=False, sharey=False, squeeze=True,
          scale=None, xscale='log', xlabel='', xlim=None, yscale='log', ylabel='', ylim=None,
          left=None, bottom=None, right=None, top=None, hspace=None, wspace=None,
          widths=None, heights=None, grid=True, **kwargs):
    """Create matplotlib figure and axes instances.

    Convenience function to create fig/axes using `plt.subplots`, and quickly modify standard
    parameters.

    Parameters
    ----------
    figsize : (2,) list, optional
        Figure size in inches.
    ncols : int, optional
        Number of columns of axes.
    nrows : int, optional
        Number of rows of axes.
    sharex : bool, optional
        Share xaxes configuration between axes.
    sharey : bool, optional
        Share yaxes configuration between axes.
    squeeze : bool, optional
        Remove dimensions of length (1,) in the `axes` object.
    scale : str or None, optional
        Axes scaling to be applied to all x/y axes.  One of ['log', 'lin'].
        If given, this overrides both `xscale` and `yscale`.
    xscale : str or array_like of str, optional
        Axes scaling for xaxes ['log', 'lin'].  Any name starting with 'lin' is expanded to
        'linear'.  An array broadcastable to (nrows, ncols) sets the scale of each axes.
    xlabel : str or array_like of str, optional
        Label for xaxes (broadcast against the axes).
    xlim : None or array_like, optional
        Limits for xaxes: a single (2,) pair applied to all axes, or an array broadcastable to
        (nrows, ncols, 2).  If `None`, limits are not set.
    yscale : str or array_like of str, optional
        Axes scaling for yaxes ['log', 'lin'], handled in the same way as `xscale`.
    ylabel : str or array_like of str, optional
        Label for yaxes (broadcast against the axes).
    ylim : None or array_like, optional
        Limits for yaxes, handled in the same way as `xlim`.
    left : float or None, optional
        Left edge of axes space, set using `subplots_adjust()`, as a fraction of figure.
    bottom : float or None, optional
        Bottom edge of axes space, set using `subplots_adjust()`, as a fraction of figure.
    right : float or None, optional
        Right edge of axes space, set using `subplots_adjust()`, as a fraction of figure.
    top : float or None, optional
        Top edge of axes space, set using `subplots_adjust()`, as a fraction of figure.
    hspace : float or None, optional
        Height space between axes if multiple rows are being used.
    wspace : float or None, optional
        Width space between axes if multiple columns are being used.
    widths : array_like or None, optional
        Passed to `plt.subplots` as ``gridspec_kw['width_ratios']``.
    heights : array_like or None, optional
        Passed to `plt.subplots` as ``gridspec_kw['height_ratios']``.
    grid : bool, optional
        If `True`, draw axis elements (including grid lines) below the plotted artists.
    **kwargs : dict
        Additional keyword arguments passed to `plt.subplots`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        New matplotlib figure instance containing axes.
    axes : [ndarray] `matplotlib.axes.Axes`
        New matplotlib axes, either a single instance or an ndarray of axes.

    """

    def _full_scale_name(name):
        # matplotlib requires 'linear'; accept any abbreviation that starts with 'lin'
        if isinstance(name, str) and name.startswith('lin'):
            return 'linear'
        return name

    def _limits_per_axes(lim):
        # Expand `lim` to one entry per axes: (nrows, ncols, 2) when limits are given,
        # or an (nrows, ncols) object array of `None` when they are not.
        if lim is None:
            return np.broadcast_to(lim, (nrows, ncols))
        if np.shape(lim) == (2,):
            lim = np.array(lim)[np.newaxis, np.newaxis, :]
        return np.broadcast_to(lim, (nrows, ncols, 2))

    if scale is not None:
        xscale = scale
        yscale = scale

    # Apply elementwise, so that per-axes arrays of scale names work as well as single strings
    expand = np.vectorize(_full_scale_name, otypes=[object])
    xscale = expand(xscale)
    yscale = expand(yscale)

    if (widths is not None) or (heights is not None):
        # Extend (a copy of) any user-provided `gridspec_kw` instead of discarding it
        gridspec_kw = dict(kwargs.get('gridspec_kw') or {})
        if widths is not None:
            gridspec_kw['width_ratios'] = widths
        if heights is not None:
            gridspec_kw['height_ratios'] = heights
        kwargs['gridspec_kw'] = gridspec_kw

    # Always create a 2D array of axes here; squeezing (if requested) happens at the end
    fig, axes = plt.subplots(figsize=figsize, squeeze=False, ncols=ncols, nrows=nrows,
                             sharex=sharex, sharey=sharey, **kwargs)

    fig.subplots_adjust(
        left=left, bottom=bottom, right=right, top=top, hspace=hspace, wspace=wspace)

    ylim = _limits_per_axes(ylim)
    xlim = _limits_per_axes(xlim)

    _, xscale, xlabel = np.broadcast_arrays(axes, xscale, xlabel)
    _, yscale, ylabel = np.broadcast_arrays(axes, yscale, ylabel)

    for idx, ax in np.ndenumerate(axes):
        ax.set(xscale=xscale[idx], xlabel=xlabel[idx], yscale=yscale[idx], ylabel=ylabel[idx])
        if xlim[idx] is not None:
            ax.set_xlim(xlim[idx])
        if ylim[idx] is not None:
            ax.set_ylim(ylim[idx])

        if grid is True:
            ax.set_axisbelow(True)

    if squeeze:
        axes = np.squeeze(axes)
        # A single axes is left as a 0-dimensional array: extract the `Axes` instance itself
        if np.ndim(axes) == 0:
            axes = axes[()]

    return fig, axes
def smap(args=[0.0, 1.0], cmap=None, log=False, norm=None, midpoint=None,
         under='0.8', over='0.8', left=None, right=None):
    """Create a colormap from a scalar range to a set of colors.

    Parameters
    ----------
    args : scalar or array_like of scalar
        Range of valid scalar values to normalize with.  Only used if `norm` is not given.
    cmap : None, str, or ``matplotlib.colors.Colormap`` object
        Colormap to use.  If `None`: 'bwr' is used when a `midpoint` is given, otherwise
        'Spectral'.  A colormap instance passed in is copied, never modified.
    log : bool
        Logarithmic scaling.  Only used if `norm` is not given.
    norm : None or `matplotlib.colors.Normalize`
        Normalization to use.  If `None`, one is constructed from `args`, `midpoint` and `log`.
    midpoint : None or scalar
        If not `None`, the value that should be mapped onto the center of the colormap.
        Only used if `norm` is not given.
    under : str or `None`
        Color specification for values below range.
    over : str or `None`
        Color specification for values above range.
    left : float {0.0, 1.0} or `None`
        Truncate the left edge of the colormap to this value.
        If `None`, 0.0 used (if `right` is provided).
    right : float {0.0, 1.0} or `None`
        Truncate the right edge of the colormap to this value
        If `None`, 1.0 used (if `left` is provided).

    Returns
    -------
    smap : ``matplotlib.cm.ScalarMappable``
        Scalar mappable object which contains the members:
        `norm`, `cmap`, and the function `to_rgba`.
        An additional attribute `log` (bool) stores whether the mapping is logarithmic.

    """
    _DEF_CMAP = 'Spectral'
    _DEF_CMAP_MIDPOINT = 'bwr'

    if cmap is None:
        cmap = _DEF_CMAP_MIDPOINT if (midpoint is not None) else _DEF_CMAP

    # Colormap instances provided by the caller are returned as-is by `_get_cmap`; work on a
    # copy so that `set_under` / `set_over` below do not modify the caller's (possibly shared)
    # colormap object.
    if isinstance(cmap, mpl.colors.Colormap):
        cmap = cmap.copy()
    else:
        cmap = _get_cmap(cmap)

    # Select a truncated subsection of the colormap
    if (left is not None) or (right is not None):
        if left is None:
            left = 0.0
        if right is None:
            right = 1.0
        cmap = _cut_cmap(cmap, left, right)

    if under is not None:
        cmap.set_under(under)
    if over is not None:
        cmap.set_over(over)

    if norm is None:
        norm = _get_norm(args, midpoint=midpoint, log=log)
    else:
        # When a normalization is given explicitly, it determines whether the mapping is log
        log = isinstance(norm, mpl.colors.LogNorm)

    # Create scalar-mappable
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    # Allow `smap` to be used to construct colorbars
    mappable.set_array([])
    # Store type of mapping
    mappable.log = log

    return mappable
def _get_norm(data, midpoint=None, log=False):
    """Construct a matplotlib normalization for the given data.

    Parameters
    ----------
    data : scalar or array_like of scalar
        * size 1 : interpreted as an integer count ``N``, giving the range ``[0, N-1]``.
        * size 2 : interpreted as the ``(min, max)`` range.
        * otherwise : the range is determined using `utils.minmax`.
    midpoint : None or scalar
        If not `None`, the value that should be mapped onto the center of the range.
    log : bool
        Use a logarithmic normalization.

    Returns
    -------
    norm : `matplotlib.colors.Normalize`
        One of `Normalize`, `LogNorm`, `MidpointNormalize`, or `MidpointLogNormalize`.

    Raises
    ------
    ValueError
        If the extrema of `data` cannot be determined.

    """
    # Determine minimum and maximum
    if np.size(data) == 1:
        lo = 0
        # `np.int` has been removed from numpy; the builtin `int` is equivalent
        hi = int(np.ravel(data)[0]) - 1
    elif np.size(data) == 2:
        lo, hi = data
    else:
        try:
            lo, hi = utils.minmax(data, filter=log)
        except Exception as exc:
            err = f"Input `data` ({type(data)}) must be an integer, (2,) of scalar, or ndarray of scalar!"
            # NOTE: the `log` argument (a bool) shadows the module-level logger within this
            # function, so the logger must be accessed through the package.
            holo.log.exception(err)
            raise ValueError(err) from exc

    # Create normalization
    if log:
        if midpoint is None:
            norm = mpl.colors.LogNorm(vmin=lo, vmax=hi)
        else:
            norm = MidpointLogNormalize(vmin=lo, vmax=hi, midpoint=midpoint)
    else:
        if midpoint is None:
            norm = mpl.colors.Normalize(vmin=lo, vmax=hi)
        else:
            norm = MidpointNormalize(vmin=lo, vmax=hi, midpoint=midpoint)

    return norm


def _cut_cmap(cmap, min=0.0, max=1.0, n=100):
    """Select a truncated subset of the given colormap.

    The colormap is sampled at `n` points between the fractions `min` and `max`, and a new
    `LinearSegmentedColormap` is built from those samples.

    Code from: http://stackoverflow.com/a/18926541/230468

    """
    name = f"trunc({cmap.name},{min:.2f},{max:.2f})"
    samples = cmap(np.linspace(min, max, n))
    return mpl.colors.LinearSegmentedColormap.from_list(name, samples)


def _get_cmap(cmap):
    """Retrieve a colormap with the given name if it is not already a colormap.

    Colormap instances are returned unchanged (not copied).  Colormaps retrieved by name are
    copies, which can be modified without affecting the registered colormap.

    """
    if isinstance(cmap, mpl.colors.Colormap):
        return cmap

    try:
        # `mpl.cm.get_cmap` has been deprecated and then removed from matplotlib; prefer the
        # colormap registry whenever the installed version provides it.
        registry = getattr(mpl, 'colormaps', None)
        if registry is None:
            return mpl.cm.get_cmap(cmap).copy()
        if cmap is None:
            # `get_cmap(None)` returned the default colormap; keep that behavior
            cmap = mpl.rcParams['image.cmap']
        return registry[cmap].copy()
    except Exception as err:
        log.error(f"Could not load colormap from `{cmap}` : {err}")
        raise


def _get_hist_steps(xx, yy, yfilter=None):
    """Convert from bin-edges and histogram heights, to specifications for step lines.

    Parameters
    ----------
    xx : array_like
        Independence variable representing bin-edges.  Size (N,)
    yy : array_like
        Dependence variable representing histogram amplitudes.  Size (N-1,)
    yfilter : None, bool, callable
        * `None` or `False` : return all points.
        * `True` : return only the points with y-values greater than zero.
        * callable : called with the y-values; must return an index/mask of points to keep.

    Returns
    -------
    xnew : array (2*(N-1),)
        x-values: the left and right edge of each bin, in order.
    ynew : array (2*(N-1),)
        y-values: the height of each bin, repeated for both of its edges.

    Raises
    ------
    ValueError
        If the sizes of `xx` and `yy` are inconsistent, or `yfilter` is not recognized.

    """
    xx = np.asarray(xx)
    yy = np.asarray(yy)
    if len(xx) - 1 != len(yy):
        err = f"Length of `xx` ({len(xx)}) should be length of `yy` ({len(yy)}) + 1!"
        log.error(err)
        raise ValueError(err)

    # [left, right] edges of each bin, and the bin's height at both edges
    xnew = np.column_stack([xx[:-1], xx[1:]]).flatten()
    ynew = np.repeat(yy, 2)

    if yfilter not in [None, False]:
        if yfilter is True:
            idx = (ynew > 0.0)
        elif callable(yfilter):
            idx = yfilter(ynew)
        else:
            raise ValueError(f"Unrecognized `yfilter` ({yfilter})!  Must be None, bool or callable.")

        xnew = xnew[idx]
        ynew = ynew[idx]

    return xnew, ynew


def draw_hist_steps(ax, xx, yy, yfilter=None, **kwargs):
    """Plot a histogram (bin-edges `xx`, heights `yy`) as a step line; see `_get_hist_steps`.

    Additional keyword arguments are passed to `ax.plot`, the result of which is returned.
    """
    return ax.plot(*_get_hist_steps(xx, yy, yfilter=yfilter), **kwargs)


def draw_gwb(ax, xx, gwb, nsamp=10, color=None, label=None, **kwargs):
    """Draw the median and confidence intervals of a GWB, plus some individual realizations.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axes on which to draw.
    xx : array_like
        x-values at which `gwb` is plotted.
    gwb : ndarray
        Values to plot, indexed as ``gwb[:, ii]`` for realization ``ii``.
    nsamp : int or None
        Number of randomly-chosen individual realizations to draw.  None are drawn if this is
        `None` or not positive.
    color : None or color specification
        Color for all lines; if `None`, the next color of the axes color-cycle is used.
    label : None or str
        If not `None`, label of the median line (unless the `plot` dict already has a label).
    **kwargs : dict
        Passed to `draw_med_conf`.  A `plot` entry (dict) is used as the plotting arguments of
        the median line; the caller's dict is not modified.

    Returns
    -------
    hh : return value of `draw_med_conf`

    """
    if color is None:
        color = ax._get_lines.get_next_color()

    # Copy, so that the caller's dictionary is not modified
    kw_plot = dict(kwargs.pop('plot', {}))
    kw_plot.setdefault('color', color)
    if label is not None:
        kw_plot.setdefault('label', label)
    hh = draw_med_conf(ax, xx, gwb, plot=kw_plot, **kwargs)

    if (nsamp is not None) and (nsamp > 0):
        nsamp_max = gwb.shape[1]
        idx = np.random.choice(nsamp_max, min(nsamp, nsamp_max), replace=False)
        for ii in idx:
            ax.plot(xx, gwb[:, ii], color=color, alpha=0.25, lw=1.0, ls='-')

    return hh


def draw_ss_and_gwb(ax, xx, hc_ss, gwb, nsamp=10, color=None, cmap = cm.rainbow, sslabel=None, bglabel=None, **kwargs):
    """Draw individual realizations of the GWB (lines) and of single sources (scatter points).

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axes on which to draw.
    xx : array_like
        x-values at which the data are plotted.
    hc_ss : ndarray
        Single-source values, indexed as ``hc_ss[:, ii, ll]`` for realization ``ii`` and
        source ``ll``.  The ``ll = 0`` source of each realization is drawn with a black edge.
    gwb : ndarray
        Background values, indexed as ``gwb[:, ii]`` for realization ``ii``.
    nsamp : int or None
        Number of randomly-chosen realizations to draw.  Nothing is drawn if this is `None` or
        not positive.
    color : None or color specification
        Color for all realizations; if `None`, the next color of the axes color-cycle is used.
    cmap : `matplotlib.colors.Colormap`
        Currently has no effect: `color` is always set (see above), so that realizations are
        never colored using the colormap.
    sslabel : None or str
        Legend label, attached to the ``ll = 0`` source of the first realization drawn.
    bglabel : None or str
        Legend label, attached to the background of the first realization drawn.
    **kwargs : dict
        Currently unused.

    """
    if color is None:
        color = ax._get_lines.get_next_color()

    if (nsamp is None) or (nsamp <= 0):
        return

    nsamp_max = gwb.shape[1]
    nsize = min(nsamp, nsamp_max)
    idx = np.random.choice(nsamp_max, nsize, replace=False)
    num_sources = hc_ss.shape[-1]

    for count, ii in enumerate(idx):
        # Attach labels to the first realization that is drawn, so that they appear (once)
        # regardless of which realizations were randomly selected.
        first = (count == 0)
        ax.plot(xx, gwb[:, ii], color=color, alpha=0.25, lw=1.0, ls='-',
                label=bglabel if first else None)

        for ll in range(num_sources):
            edgecolor = 'k' if (ll == 0) else None
            label = sslabel if (first and (ll == 0)) else None
            ax.scatter(xx, hc_ss[:, ii, ll], color=color, alpha=0.25,
                       edgecolor=edgecolor, label=label)

    return


def plot_gwb(fobs, gwb, hc_ss=None, bglabel=None, sslabel=None, **kwargs):
    """Create a figure showing the GWB, and optionally single sources, versus frequency.

    Parameters
    ----------
    fobs : array_like
        Frequencies; these are multiplied by `YR` for the (bottom) x-axis, and a twin axis
        (`_twin_hz`) is added on top.
    gwb : ndarray
        Background values, passed to `draw_gwb` or `draw_ss_and_gwb`.
    hc_ss : None or ndarray
        If given, single-source values: realizations are then drawn with `draw_ss_and_gwb`.
    bglabel, sslabel : None or str
        Legend labels, only used if `hc_ss` is given.
    **kwargs : dict
        Passed to `draw_gwb` or `draw_ss_and_gwb`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`

    """
    xx = fobs * YR
    fig, ax = figax(
        xlabel=LABEL_GW_FREQUENCY_YR,
        ylabel=LABEL_CHARACTERISTIC_STRAIN
    )
    if hc_ss is not None:
        draw_ss_and_gwb(ax, xx, hc_ss, gwb, sslabel=sslabel, bglabel=bglabel, **kwargs)
    else:
        draw_gwb(ax, xx, gwb, **kwargs)

    _twin_hz(ax)
    return fig
def plot_bg_ss(fobs, bg, ss=None, bglabel=None, sslabel=None,
               xlabel=LABEL_GW_FREQUENCY_YR, ylabel=LABEL_CHARACTERISTIC_STRAIN, **kwargs):
    """Create a figure of the background, and optionally single sources, versus frequency.

    Can plot strain or power spectral density, just need to set ylabel accordingly.

    Parameters
    ----------
    fobs : array_like
        Frequencies; these are multiplied by `YR` for the (bottom) x-axis, and a twin axis
        (`_twin_hz`) is added on top.
    bg : ndarray
        Background values, passed to `draw_gwb` or `draw_ss_and_gwb`.
    ss : None or ndarray
        If given, single-source values: the data are then drawn with `draw_ss_and_gwb`.
    bglabel, sslabel : None or str
        Legend labels, only used if `ss` is given.
    xlabel, ylabel : str
        Axis labels.
    **kwargs : dict
        Passed to `draw_gwb` or `draw_ss_and_gwb`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`

    """
    freqs = fobs * YR
    fig, ax = figax(xlabel=xlabel, ylabel=ylabel)

    if ss is None:
        draw_gwb(ax, freqs, bg, **kwargs)
    else:
        draw_ss_and_gwb(ax, freqs, ss, bg, sslabel=sslabel, bglabel=bglabel, **kwargs)

    _twin_hz(ax)
    return fig
def draw_sspars_and_bgpars(axs, xx, sspar, bgpar, nsamp=10, cmap=cm.rainbow_r, color = None, label=None, **kwargs):
    """Draw binary parameters of the background (lines) and single sources (scatter points).

    Four parameters are drawn, one in each panel of a (2, 2) grid of axes:

    * ``axs[0,0]`` : parameter 0, divided by `MSOL` (masses in solar masses, per the original
      code comments)
    * ``axs[0,1]`` : parameter 1 (mass ratios, per the original code comments)
    * ``axs[1,0]`` : `holo.cosmo.comoving_distance` of parameter 2 (initial distances)
    * ``axs[1,1]`` : `holo.cosmo.comoving_distance` of parameter 3 (final distances)

    Parameters
    ----------
    axs : (2, 2) ndarray of `matplotlib.axes.Axes`
        Axes on which to draw.
    xx : array_like
        x-values at which the parameters are plotted.
    sspar : ndarray
        Single-source parameters, indexed as ``sspar[pp, :, ii, ll]`` for parameter ``pp``,
        realization ``ii`` and source ``ll``.  The ``ll = 0`` source of each realization is
        drawn with a black edge.
    bgpar : ndarray
        Background parameters, indexed as ``bgpar[pp, :, ii]``.
    nsamp : int or None
        Number of randomly-chosen realizations to draw.  Nothing is drawn if this is `None` or
        not positive.
    cmap : `matplotlib.colors.Colormap`
        Colormap from which the color of each drawn realization is taken.
    color, label, **kwargs
        Currently unused.

    """
    # (axes, background values, single-source values) for each of the four panels
    panels = [
        (axs[0, 0], bgpar[0, :, :]/MSOL, sspar[0, :, :, :]/MSOL),
        (axs[0, 1], bgpar[1, :, :], sspar[1, :, :, :]),
        (axs[1, 0],
         holo.cosmo.comoving_distance(bgpar[2, :, :]).value,
         holo.cosmo.comoving_distance(sspar[2, :, :, :]).value),
        (axs[1, 1],
         holo.cosmo.comoving_distance(bgpar[3, :, :]).value,
         holo.cosmo.comoving_distance(sspar[3, :, :, :]).value),
    ]

    if (nsamp is None) or (nsamp <= 0):
        return

    nsamp_max = bgpar.shape[2]
    nsize = min(nsamp, nsamp_max)
    colors = cmap(np.linspace(0, 1, nsize))
    idx = np.random.choice(nsamp_max, nsize, replace=False)
    num_sources = sspar.shape[-1]

    # Each realization is drawn in its own color, in all panels
    for cc, ii in zip(colors, idx):
        # background
        for ax, bg_vals, _ in panels:
            ax.plot(xx, bg_vals[:, ii], color=cc, alpha=0.25, lw=1.0, ls='-')

        # single sources
        for ll in range(num_sources):
            edgecolor = 'k' if (ll == 0) else None
            for ax, _, ss_vals in panels:
                ax.scatter(xx, ss_vals[:, ii, ll], color=cc, alpha=0.25, edgecolor=edgecolor)

    return


def plot_pars(fobs, sspar, bgpar, **kwargs):
    """Create a (2, 2) figure of background and single-source binary parameters vs. frequency.

    Parameters
    ----------
    fobs : array_like
        Frequencies; these are multiplied by `YR` for the x-axes.
    sspar : ndarray
        Single-source parameters, see `draw_sspars_and_bgpars`.
    bgpar : ndarray
        Background parameters, see `draw_sspars_and_bgpars`.
    **kwargs : dict
        Currently unused.

    Returns
    -------
    fig : `matplotlib.figure.Figure`

    """
    xx = fobs * YR
    fig, axs = figax(figsize=(11, 6), ncols=2, nrows=2, sharex=True)

    # Raw string: '\o' is not a valid escape sequence in a regular string literal
    axs[0, 0].set_ylabel(r'Total Mass $M/M_\odot$')
    axs[0, 1].set_ylabel('Mass Ratio $q$')
    axs[1, 0].set_ylabel('Initial Comoving Distance $d_c$ (Mpc)')
    axs[1, 1].set_ylabel('Final Comoving Distance $d_c$ (Mpc)')
    axs[1, 0].set_xlabel(LABEL_GW_FREQUENCY_YR)
    axs[1, 1].set_xlabel(LABEL_GW_FREQUENCY_YR)

    # NOTE: `draw_sspars_and_bgpars` does not currently use its `color` argument
    draw_sspars_and_bgpars(axs, xx, sspar, bgpar, color='pink')
    fig.tight_layout()
    return fig
def scientific_notation(val, man=1, exp=0, dollar=True):
    """Convert a scalar into a string with scientific notation (latex formatted).

    Arguments
    ---------
    val : scalar
        Numerical value to convert.
    man : int or `None`
        Precision of the mantissa (decimal points); or `None` for omit mantissa, in which case
        only the sign and the power of ten (exponent rounded down) are shown.
        A mantissa that formats as exactly one (e.g. '1.0') is always omitted.
    exp : int or `None`
        Accepted for backwards compatibility, but has no effect: the exponent is always
        included, formatted as an integer.
    dollar : bool
        Include dollar-signs ('$') around returned expression.

    Returns
    -------
    rv_str : str
        Scientific notation string using latex formatting.

    """
    wrap = "$"*dollar
    if val == 0.0:
        return wrap + "0.0" + wrap

    # get log10 exponent
    val_exp = np.floor(np.log10(np.fabs(val)))
    # get mantissa (positive/negative is still included here)
    val_man = val / np.power(10.0, val_exp)

    # Construct Mantissa String
    # --------------------------------
    if man is None:
        # Without a mantissa, only the sign of the value needs to be retained
        str_man = "-" if (val < 0.0) else ""
        need_times = False
    else:
        val_man = np.around(val_man, man)
        # Rounding can push the magnitude of the mantissa up to 10 (e.g. 9.96 ==> 10.0, or
        # -9.96 ==> -10.0), in which case it is shifted into the exponent.
        if np.fabs(val_man) >= 10.0:
            val_man /= 10.0
            val_exp += 1

        str_man = "{0:.{1:d}f}".format(val_man, man)
        # If the mantissa is '1' (or '1.0' or '1.00' etc), dont write it
        if str_man == "{0:.{1:d}f}".format(1.0, man):
            str_man = ""
        need_times = (len(str_man) > 0)

    # Construct Exponent String
    # --------------------------------
    str_exp = "10^{{ {:d} }}".format(int(val_exp))

    # Put them together
    # --------------------------------
    rv_str = wrap + str_man
    if need_times:
        rv_str += " \\times"
    rv_str += str_exp + wrap

    return rv_str
def _draw_plaw(ax, freqs, amp=1e-15, f0=1/YR, **kwargs):
    """Draw the power-law ``amp * (freqs / f0) ** (-2/3)`` on the given axes.

    Additional `kwargs` are passed to `ax.plot`; by default the line is drawn
    black, dashed and semi-transparent.  Returns the result of `ax.plot`.
    """
    kwargs.setdefault('alpha', 0.25)
    kwargs.setdefault('color', 'k')
    kwargs.setdefault('ls', '--')
    plaw = amp * np.power(np.asarray(freqs)/f0, -2/3)
    return ax.plot(freqs, plaw, **kwargs)


def _twin_hz(ax, nano=True, fs=10, **kw):
    """Add a twin x-axis whose limits are those of `ax` divided by `YR`.

    If `nano` is true the limits are additionally multiplied by 1e9 and the
    nHz label is used, otherwise the Hz label is used.
    NOTE(review): presumably `ax` is in units of 1/yr -- confirm against callers.
    Returns the new (twin) axes.
    """
    tw = ax.twiny()
    tw.grid(False)
    xlim = np.array(ax.get_xlim()) / YR
    if nano:
        label = LABEL_GW_FREQUENCY_NHZ
        xlim *= 1e9
    else:
        label = LABEL_GW_FREQUENCY_HZ

    tw.set(xlim=xlim, xscale=ax.get_xscale())
    tw.set_xlabel(label, fontsize=fs, **kw)
    return tw


def _twin_yr(ax, nano=True, fs=10, label=True, **kw):
    """Add a twin x-axis whose limits are those of `ax` multiplied by `YR`.

    If `nano` is true the limits are additionally divided by 1e9 (i.e. the
    parent axis is taken to be in nano-units).  The 1/yr label is only set
    when `label` is true.  Returns the new (twin) axes.
    """
    tw = ax.twiny()
    tw.grid(False)
    xlim = np.array(ax.get_xlim()) * YR
    if nano:
        xlim /= 1e9

    tw.set(xlim=xlim, xscale=ax.get_xscale())
    if label:
        tw.set_xlabel(LABEL_GW_FREQUENCY_YR, fontsize=fs, **kw)
    return tw


def _conf_quantiles(fracs):
    """Convert central-interval fractions into the list of quantiles to evaluate.

    Returns `(percs, inter_percs)` where `percs` is `fracs` as an array, and
    `inter_percs` is ``[0.5, lo_0, hi_0, lo_1, hi_1, ...]``: the median followed
    by the lower/upper bound of each interval centered on 50%.
    """
    percs = np.atleast_1d(fracs)
    assert np.all((0.0 <= percs) & (percs <= 1.0))
    # center the target percentages into pairs around 50%, e.g.  68 ==> [16,84]
    inter_percs = [[0.5-pp/2, 0.5+pp/2] for pp in percs]
    # Add the median value (50%)
    inter_percs = [0.5, ] + np.concatenate(inter_percs).tolist()
    return percs, inter_percs


def draw_med_conf(ax, xx, vals, fracs=[0.50, 0.90], weights=None, plot=None, fill=None, filter=False):
    """Draw the median of `vals` as a line, and central intervals as shaded bands.

    Identical to `draw_med_conf_color` without an explicit color or linestyle.

    Arguments
    ---------
    ax : matplotlib axes to draw on.
    xx : (X,) array of x-values.
    vals : (X, N) array; quantiles are taken along the last axis.
    fracs : scalar or list of central-interval fractions, each in [0.0, 1.0].
    weights : passed through to `kale.utils.quantiles`.
    plot : dict or `None`, extra arguments for the median line (`ax.plot`).
    fill : dict or `None`, extra arguments for the bands (`ax.fill_between`).
    filter : bool, if true only positive values of each row of `vals` are used.

    Returns
    -------
    (hh, gg) : handle of the median line, and of the last band drawn.

    """
    return draw_med_conf_color(
        ax, xx, vals, fracs=fracs, weights=weights, plot=plot, fill=fill, filter=filter, color=None
    )


def draw_med_conf_color(ax, xx, vals, fracs=[0.50, 0.90], weights=None, plot=None, fill=None,
                        filter=False, color=None, linestyle='-'):
    """Draw the median of `vals` as a line, and central intervals as shaded bands.

    Same as `draw_med_conf`, but `color` and `linestyle` can be given for the
    median line.  NOTE: `linestyle` is only applied when `color` is not `None`.
    The bands use the color of the median line unless `fill` overrides it.

    The `plot` and `fill` dictionaries are copied, never modified in place.

    Returns
    -------
    (hh, gg) : handle of the median line, and of the last band drawn.

    """
    # Work on copies so neither the caller's dictionaries nor any shared default is mutated
    plot = {} if plot is None else dict(plot)
    fill = {} if fill is None else dict(fill)
    plot.setdefault('alpha', 0.75)
    fill.setdefault('alpha', 0.2)

    percs, inter_percs = _conf_quantiles(fracs)

    # Get percentiles; they go along the last axis
    if filter:
        # NOTE(review): `weights` are not filtered along with the values -- confirm intended usage
        rv = np.asarray([
            kale.utils.quantiles(vv[vv > 0.0], percs=inter_percs, weights=weights)
            for vv in vals
        ])
    else:
        rv = kale.utils.quantiles(vals, percs=inter_percs, weights=weights, axis=-1)

    med, *conf = rv.T

    # plot median
    if color is not None:
        hh, = ax.plot(xx, med, color=color, linestyle=linestyle, **plot)
    else:
        hh, = ax.plot(xx, med, **plot)

    # Reshape confidence intervals to nice plotting shape
    # 2*P, X ==> (P, 2, X)
    conf = np.array(conf).reshape(len(percs), 2, np.size(xx))

    # bands default to the color of the median line, but `fill` takes precedence
    fill_kw = dict(color=hh.get_color())
    fill_kw.update(fill)

    # plot each confidence interval
    gg = None
    for lo, hi in conf:
        gg = ax.fill_between(xx, lo, hi, **fill_kw)

    return (hh, gg)


def smooth_spectra(xx, gwb, smooth=(20, 4), interp=100):
    """Interpolate spectra onto a log-spaced grid and smooth with a Savitzky-Golay filter.

    Arguments
    ---------
    xx : (X,) array of x-values.
    gwb : (X, 1+2P) array; the first column is treated as the median, and the
        following columns as `P` pairs of (lower, upper) interval bounds.
    smooth : (2,) filter-window size and polynomial-order for `savgol_filter`.
    interp : number of points in the new, log-spaced grid.

    Returns
    -------
    xnew : (interp,) new x-values.
    med : (interp,) smoothed median.
    conf : (P, 2, interp) smoothed interval bounds.

    Raises
    ------
    ValueError : if `smooth` does not have exactly two elements.

    """
    assert np.shape(xx) == (np.shape(gwb)[0],)

    if len(smooth) != 2:
        err = f"{smooth=} must be a (2,) of float specifying the filter-window size and polynomial-order!!"
        raise ValueError(err)

    xnew = kale.utils.spacing(xx, 'log', num=int(interp))
    rv = [utils.interp(xnew, xx, vv) for vv in gwb.T]
    rv = sp.signal.savgol_filter(rv, *smooth, axis=-1)
    med, *conf = rv
    # Reshape confidence intervals to nice plotting shape
    # 2*P, X ==> (P, 2, X)
    npercs = np.shape(conf)[0] // 2
    conf = np.array(conf).reshape(npercs, 2, xnew.size)
    return xnew, med, conf


def get_med_conf(vals, fracs, weights=None, axis=-1):
    """Calculate the median and central-interval bounds of `vals` along `axis`.

    The quantiles requested from `kale.utils.quantiles` are
    ``[0.5, lo_0, hi_0, lo_1, hi_1, ...]`` for each fraction in `fracs`.
    """
    _, inter_percs = _conf_quantiles(fracs)
    # Get percentiles; they go along the given axis
    rv = kale.utils.quantiles(vals, percs=inter_percs, weights=weights, axis=axis)
    return rv


def draw_smooth_med_conf(ax, xx, vals, smooth=(10, 4), interp=100,
                         fracs=[0.50, 0.90], weights=None, plot=None, fill=None):
    """Draw a smoothed median line and smoothed central-interval bands.

    Quantiles are calculated with `get_med_conf` (along the last axis), then
    interpolated and smoothed with `smooth_spectra` before being drawn.

    The `plot` and `fill` dictionaries are copied, never modified in place.

    Returns
    -------
    (hh, gg) : handle of the median line, and of the last band drawn.

    """
    plot = {} if plot is None else dict(plot)
    fill = {} if fill is None else dict(fill)
    plot.setdefault('alpha', 0.5)
    fill.setdefault('alpha', 0.2)

    rv = get_med_conf(vals, fracs, weights, axis=-1)
    xnew, med, conf = smooth_spectra(xx, rv, smooth=smooth, interp=interp)

    # plot median
    hh, = ax.plot(xnew, med, **plot)

    # bands default to the color of the median line; merging (instead of passing `color`
    # separately) avoids a duplicate-keyword error when `fill` also specifies a color
    fill_kw = dict(color=hh.get_color())
    fill_kw.update(fill)

    # plot each confidence interval
    gg = None
    for lo, hi in conf:
        gg = ax.fill_between(xnew, lo, hi, **fill_kw)

    return (hh, gg)


def violins(ax, xx, yy, zz, width, **kwargs):
    """Draw one violin (see `violin`) at each location in `xx`.

    Arguments
    ---------
    ax : matplotlib axes to draw on.
    xx : (X,) array of violin locations.
    yy : (N,) or (X, N) or (N, X) array of positions along each violin; a 1D
        array is reused for every violin.
    zz : array of violin amplitudes, with the same shape as (the expanded) `yy`.
    width : scalar, maximum half-width of each violin.
    **kwargs : passed through to `violin`.

    Returns
    -------
    handle : the handle returned for the last violin drawn (`None` if `xx` is empty).

    """
    assert np.ndim(xx) == 1
    nlocs = np.size(xx)
    if np.ndim(yy) == 1:
        yy = [yy] * nlocs
    yy = np.asarray(yy)
    zz = np.asarray(zz)
    assert np.ndim(yy) == 2
    assert np.shape(yy) == np.shape(zz)
    # allow the location axis to be given as the second one
    if (np.shape(yy)[0] != nlocs) and (np.shape(yy)[1] == nlocs):
        yy = yy.T
        zz = zz.T
    assert np.shape(yy)[0] == nlocs
    assert np.shape(zz)[0] == nlocs

    handle = None
    for xval, usey, usez in zip(xx, yy, zz):
        handle = violin(ax, xval, usey, usez, width, **kwargs)

    return handle


def violin(ax, xx, yy, zz, width, median_log10=False, side='both', clip_pdf=None,
           median={}, line={}, fill={}, **kwargs):
    """Draw a single violin at location `xx`, with outline `xx -/+ zz * width` along `yy`.

    Arguments
    ---------
    ax : matplotlib axes to draw on.
    xx : scalar, location of the violin.
    yy : (N,) array of positions along the violin.
    zz : (N,) array of amplitudes; normalized by its maximum before drawing.
    width : scalar, half-width corresponding to the maximum of `zz`.
    median_log10 : bool, integrate in `log10(yy)` when locating the median.
    side : str beginning with 'l' (left), 'r' (right) or 'b' (both).
    clip_pdf : `None` or scalar < 1.0; only points where the normalized `zz`
        exceeds this value are drawn.
    median : dict, bool or `None`; style of the median marker.  `True` is the
        same as `{}`; `False` or `None` disables the marker.  A 'width' entry
        sets the marker half-length as a fraction of `width` (default 0.5).
    line : dict or `None`; style of the outline, `None` disables the outline.
    fill : dict or `None`; style of the filled region, `None` disables the fill.
    **kwargs : applied to both the outline and fill styles.

    Returns
    -------
    handle : the single handle if only one of outline/fill is drawn, otherwise
        a tuple of the handles that were drawn.

    Raises
    ------
    ValueError : if `side` does not begin with one of 'l', 'r', 'b'.

    """
    assert np.ndim(xx) == 0
    assert np.shape(xx) == np.shape(width)
    assert np.ndim(yy) == 1
    assert yy.shape == zz.shape

    valid_sides = ['l', 'r', 'b']
    if side[0] not in valid_sides:
        raise ValueError(f"{side=} must begin with one of {valid_sides}!")

    # The line-style is always constructed, because the median marker is based on it
    # even when the outline itself is disabled (`line=None`).
    draw_line = (line is not None)
    line_kw = dict(alpha=0.5, lw=0.5, color='k')
    line_kw.update(kwargs)
    if draw_line:
        line_kw.update(line)

    if fill is not None:
        fill_kw = dict(alpha=0.25, lw=0.0)
        fill_kw.update(kwargs)
        fill_kw.update(fill)
        fill = fill_kw

    if clip_pdf is not None:
        assert np.ndim(clip_pdf) == 0
        assert clip_pdf < 1.0

    zz = zz / zz.max()
    if median is True:
        median = {}
    if median is False:
        median = None

    if median is not None:
        # integrate (trapezoid rule) to get the CDF, and interpolate to where it reaches 50%
        dy = np.diff(np.log10(yy)) if median_log10 else np.diff(yy)
        cdf = 0.5 * (zz[1:] + zz[:-1]) * dy
        cdf = np.concatenate([[0.0, ], cdf])
        cdf = np.cumsum(cdf)
        med = np.interp([0.5], cdf/cdf.max(), yy)

    # clipping is applied after the median is found, so the median uses the full distribution
    if clip_pdf is not None:
        idx = zz > clip_pdf
        yy = yy[idx]
        zz = zz[idx]

    xl = xx * np.ones_like(yy)
    xr = xx * np.ones_like(yy)
    left_flag = side.startswith('l') or side.startswith('b')
    right_flag = side.startswith('r') or side.startswith('b')
    if left_flag:
        xl = xl - zz * width
    if right_flag:
        xr = xr + zz * width

    handle = []
    if draw_line:
        h1, = ax.plot(xl, yy, **line_kw)
        ax.plot(xr, yy, **line_kw)
        handle.append(h1)
    if fill is not None:
        h2 = ax.fill_betweenx(yy, xl, xr, **fill)
        handle.append(h2)

    if median is not None:
        med_kw = dict(line_kw)
        med_kw['lw'] = 1.0
        med_kw.update(median)
        mwid = med_kw.pop('width', 0.5)
        ll = xx
        rr = xx
        if left_flag:
            ll = ll - width * mwid
        if right_flag:
            rr = rr + width * mwid
        ax.plot([ll, rr], [med, med], **med_kw)

    handle = handle[0] if len(handle) == 1 else tuple(handle)

    return handle


class Corner:
    """Corner plot: 1D distributions on the diagonal, 2D distributions below it.

    Built on the (private) plotting helpers of `kalepy`, with support for
    dividing distributions by those of a second dataset (`ratio`).
    """

    # Fractional stretch passed to `kale.utils.minmax` when calculating axes limits from data
    _LIMITS_STRETCH = 0.1

    def __init__(self, ndim, origin='tl', rotate=True, axes=None, labels=None, limits=None, **kwargs):
        """Construct the figure and axes (unless given), and configure labels and ticks.

        Arguments
        ---------
        ndim : int, number of parameters (the grid of axes is `ndim` x `ndim`).
        origin : location of the origin, parsed by `kale.plot._parse_origin`.
        rotate : bool, treat the last diagonal panel as rotated (value on the y-axis).
        axes : `None` or (ndim, ndim) array of existing axes to use.
        labels : `None` or (ndim,) list of parameter labels.
        limits : `None` or (ndim,) list of axes limits; if `None`, limits are
            calculated from the data during plotting.
        **kwargs : passed to `kale.plot._figax` when new axes are created.

        """
        origin = kale.plot._parse_origin(origin)

        # -- Construct figure and axes
        if axes is None:
            fig, axes = kale.plot._figax(ndim, **kwargs)
            self.fig = fig
            if origin[0] == 1:
                axes = axes[::-1]
            if origin[1] == 1:
                axes = axes.T[::-1].T
        else:
            self.fig = axes[0, 0].figure

        self.origin = origin
        self.axes = axes

        last = ndim - 1
        if labels is None:
            labels = [''] * ndim

        for (ii, jj), ax in np.ndenumerate(axes):
            # Set upper-right plots to invisible
            if jj > ii:
                ax.set_visible(False)
                continue

            ax.grid(True)

            # Bottom row
            if ii == last:
                if rotate and (jj == last):
                    ax.set_ylabel(labels[jj])   # currently this is being reset to empty later, that's okay
                else:
                    ax.set_xlabel(labels[jj])

                # If vertical origin is the top
                if origin[0] == 1:
                    ax.xaxis.set_label_position('top')
                    ax.xaxis.set_ticks_position('top')

            # Non-bottom row
            else:
                ax.set_xlabel('')
                for tlab in ax.xaxis.get_ticklabels():
                    tlab.set_visible(False)

            # First column
            if jj == 0:
                # Not-first rows
                if ii != 0:
                    ax.set_ylabel(labels[ii])

                # If horizontal origin is the right
                if origin[1] == 1:
                    ax.yaxis.set_label_position('right')
                    ax.yaxis.set_ticks_position('right')

            # Not-first columns
            else:
                # if (jj != last) or (not rotate):
                ax.set_ylabel('')
                for tlab in ax.yaxis.get_ticklabels():
                    tlab.set_visible(False)

            # Diagonals
            if ii == jj:
                # not top-left
                if (ii != 0) and (origin[1] == 0):
                    ax.yaxis.set_label_position('right')
                    ax.yaxis.set_ticks_position('right')
                else:
                    ax.yaxis.set_label_position('left')
                    ax.yaxis.set_ticks_position('left')

        # If axes limits are given, set axes to them
        if limits is not None:
            limit_flag = False
            kale.plot._set_corner_axes_extrema(self.axes, limits, rotate)
        # Otherwise, prepare to calculate limits during plotting
        else:
            limits = [None] * ndim
            limit_flag = True

        # --- Store key parameters
        self.ndim = ndim
        self._rotate = rotate
        self._labels = labels
        self._limits = limits
        self._limit_flag = limit_flag

        return

    def plot(self, data, edges=None, weights=None, ratio=None, quantiles=None, sigmas=None, reflect=None,
             color=None, cmap=None, limit=None, dist1d={}, dist2d={}):
        """Draw 1D distributions on the diagonal and 2D distributions below it.

        Arguments
        ---------
        data : (D, N) array, `D` parameters (matching the axes) and `N` data-points.
        edges : bin-edge specification, parsed by `kale.utils.parse_edges`.
        weights : `None` or weights for each data-point.
        ratio : `None` or (D, M) array of a second dataset; distributions of
            `data` are divided by the distributions of `ratio`.
        quantiles, sigmas : contour levels, parsed by `kale.plot._default_quantiles`.
        reflect : `None` or (D,) reflection specification for each parameter's KDE.
        color, cmap : color and colormap, parsed by `kale.plot._parse_color_cmap`.
        limit : bool or `None`; update axes limits from the data.  `None` uses
            the value determined on construction (true if no limits were given).
        dist1d, dist2d : dict, extra arguments for `self.dist1d` / `self.dist2d`.

        Returns
        -------
        handle : handle from the last 2D panel drawn (or from the 1D panel when
            there are no 2D panels).

        Raises
        ------
        ValueError : if `data`, `ratio`, or the axes have inconsistent shapes.

        """
        if limit is None:
            limit = self._limit_flag

        # ---- Sanitize
        if np.ndim(data) != 2:
            err = "`data` (shape: {}) must be 2D with shape (parameters, data-points)!".format(
                np.shape(data))
            raise ValueError(err)

        axes = self.axes
        size = np.shape(data)[0]
        shp = np.shape(axes)
        if (np.ndim(axes) != 2) or (shp[0] != shp[1]) or (shp[0] != size):
            raise ValueError("`axes` (shape: {}) does not match data dimension {}!".format(shp, size))

        if ratio is not None:
            if np.ndim(ratio) != 2 or np.shape(ratio)[0] != size:
                err = f"`ratio` (shape: {np.shape(ratio)}) must be 2D with shape (parameters, data-points)!"
                raise ValueError(err)

        # ---- Set parameters
        last = size - 1
        rotate = self._rotate

        # Set default color or cmap as needed
        color, cmap = kale.plot._parse_color_cmap(ax=axes[0][0], color=color, cmap=cmap)

        edges = kale.utils.parse_edges(data, edges=edges)
        quantiles, _ = kale.plot._default_quantiles(quantiles=quantiles, sigmas=sigmas)

        # `handle` is always defined, even when there are no 2D panels (i.e. `ndim == 1`)
        handle = None

        # ---- Draw 1D Histograms & Carpets
        limits = [None] * size      # variable to store the data extrema
        for jj, ax in enumerate(axes.diagonal()):
            rot = (rotate and (jj == last))
            refl = reflect[jj] if reflect is not None else None
            rat = ratio[jj] if ratio is not None else None
            hh = self.dist1d(
                ax, edges[jj], data[jj], weights=weights, ratio=rat, quantiles=quantiles, rotate=rot,
                color=color, reflect=refl, **dist1d
            )
            if handle is None:
                handle = hh
            limits[jj] = kale.utils.minmax(data[jj], stretch=self._LIMITS_STRETCH)

        # ---- Draw 2D Histograms and Contours
        for (ii, jj), ax in np.ndenumerate(axes):
            if jj >= ii:
                continue
            rat = [ratio[jj], ratio[ii]] if ratio is not None else None
            handle = self.dist2d(
                ax, [edges[jj], edges[ii]], [data[jj], data[ii]], weights=weights, ratio=rat,
                color=color, cmap=cmap, quantiles=quantiles, **dist2d
            )

        # ---- calculate and set axes limits
        if limit:
            # Update any stored values
            for ii in range(self.ndim):
                self._limits[ii] = kale.utils.minmax(limits[ii], prev=self._limits[ii])

            # Set axes to limits
            kale.plot._set_corner_axes_extrema(self.axes, self._limits, self._rotate)

        return handle

    def dist1d(self, ax, edges, data, color=None, weights=None, ratio=None, probability=True, rotate=False,
               density=None, confidence=False, hist=None, carpet=True, quantiles=None,
               ls=None, alpha=None, reflect=None, **kwargs):
        """Draw a 1D distribution: KDE line, and optionally histogram, intervals and carpet.

        Arguments
        ---------
        ax : matplotlib axes to draw on.
        edges : bin edges for the histogram.
        data : (N,) array of data-points.
        color : color to use; `None` takes the next color of the axes' cycle.
        weights : `None` or weights for each data-point.
        ratio : `None` or (M,) array; distributions are divided by those of `ratio`.
        probability : passed to the KDE / histogram calculations.
        rotate : bool, draw with the value on the y-axis.
        density : `False` disables the KDE line; a number rescales the KDE by
            that factor; `None` or `True` draws the KDE without rescaling.
        confidence : bool, draw confidence intervals (not available with `ratio`).
        hist : bool, also draw a histogram.
        carpet : carpet-plot setting, parsed by `kale.plot._scatter_limit`
            (not available with `ratio`).
        quantiles : quantiles for the confidence intervals.
        ls, alpha : line-style and opacity; NOTE: for the KDE line, `alpha` is
            only applied when `ls` is also given.
        reflect : reflection specification for the KDE.
        **kwargs : passed to the line-drawing calls.

        Returns
        -------
        handle : handle of the first component that was drawn (or `None`).

        Raises
        ------
        ValueError : if `data` or `ratio` is not 1D.
        NotImplementedError : if `confidence` or `carpet` is combined with `ratio`.

        """
        if np.ndim(data) != 1:
            err = "Input `data` (shape: {}) is not 1D!".format(np.shape(data))
            raise ValueError(err)

        if ratio is not None and np.ndim(ratio) != 1:
            err = "`ratio` (shape: {}) is not 1D!".format(np.shape(ratio))
            raise ValueError(err)

        # Use `carpet` as the limiting-number of carpet-points
        # To disable the carpet, `carpet` will be set to `None`
        carpet = kale.plot._scatter_limit(carpet, "carpet")

        # set default color to next from axes' color-cycle
        if color is None:
            color = kale.plot._get_next_color(ax)

        # ---- Draw Components
        # Draw PDF from KDE
        handle = None     # variable to store a plotting 'handle' from one of the plotted objects
        if density is not False:
            kde = kale.KDE(data, weights=weights)
            # If histogram is also being plotted (as a solid line) use a dashed line
            if ls is None:
                _ls = '--' if hist else '-'
                _alpha = 0.8
            else:
                _ls = ls
                _alpha = alpha

            # Calculate KDE density distribution for the given parameter
            kde_kwargs = dict(probability=probability, params=0, reflect=reflect)
            xx, yy = kde.density(**kde_kwargs)
            if ratio is not None:
                kde_ratio = kale.KDE(ratio, weights=weights)
                _, kde_ratio = kde_ratio.density(points=xx, **kde_kwargs)
                yy /= kde_ratio

            # rescale by value of density; the default (`None`) means no rescaling, and
            # must be skipped because an array cannot be multiplied by `None`
            if density is not None:
                yy = yy * density

            # Plot
            if rotate:
                xx, yy = yy, xx
            handle, = ax.plot(xx, yy, color=color, ls=_ls, alpha=_alpha, **kwargs)

        # Draw Histogram
        if hist:
            if alpha is None:
                _alpha = 0.5 if density else 0.8
            else:
                _alpha = alpha

            _, _, hh = self.hist1d(
                ax, data, edges=edges, weights=weights, ratio=ratio, color=color,
                density=True, probability=probability, joints=True, rotate=rotate,
                ls=ls, alpha=_alpha, **kwargs
            )
            if handle is None:
                handle = hh

        # Draw Contours and Median Line
        if confidence:
            if ratio is not None:
                raise NotImplementedError("`confidence` with `ratio` is not implemented!")
            hh = kale.plot._confidence(data, ax=ax, color=color, quantiles=quantiles, rotate=rotate)
            if handle is None:
                handle = hh

        # Draw Carpet Plot
        if carpet is not None:
            if ratio is not None:
                raise NotImplementedError("`carpet` with `ratio` is not implemented!")
            hh = kale.plot._carpet(data, weights=weights, ax=ax, color=color, rotate=rotate, limit=carpet)
            if handle is None:
                handle = hh

        return handle

    def hist1d(self, ax, data, edges=None, weights=None, ratio=None, density=False, probability=False,
               renormalize=False, joints=True, positive=True, rotate=False, **kwargs):
        """Calculate and draw a 1D histogram, optionally divided by the histogram of `ratio`.

        Arguments
        ---------
        ax : matplotlib axes to draw on.
        data : (N,) array of data-points.
        edges : bin edges (passed as `bins` to `kale.utils.histogram`).
        weights : `None` or weights for the `data` histogram.
        ratio : `None` or (M,) array; if given, the histogram of `data` is
            divided by the (unweighted) histogram of `ratio` in the same bins.
        density, probability : passed to `kale.utils.histogram`.
        renormalize, joints, positive, rotate, **kwargs : passed to
            `kale.plot.draw_hist1d`.

        Returns
        -------
        hist : histogram values that were drawn.
        edges : bin edges.
        rv : return value of `kale.plot.draw_hist1d`.

        """
        hist_kwargs = dict(density=density, probability=probability)
        # Calculate histogram
        hist, edges = kale.utils.histogram(data, bins=edges, weights=weights, **hist_kwargs)
        if ratio is not None:
            # the denominator is the distribution of `ratio`, evaluated in the same bins
            hist_ratio, _ = kale.utils.histogram(ratio, bins=edges, **hist_kwargs)
            # out-of-place division: integer counts cannot be true-divided in place
            hist = hist / hist_ratio

        # Draw
        rv = kale.plot.draw_hist1d(
            ax, edges, hist,
            renormalize=renormalize, joints=joints, positive=positive, rotate=rotate,
            **kwargs
        )
        return hist, edges, rv

    def dist2d(
        self, ax, edges, data, weights=None, ratio=None, quantiles=None, sigmas=None,
        color=None, cmap=None, smooth=None, upsample=None, pad=True, outline=True,
        median=False, scatter=True, contour=True, hist=True, mask_dense=None, mask_below=True, mask_alpha=0.9
    ):
        """Draw a 2D distribution: scatter points, histogram, contours and median lines.

        Arguments
        ---------
        ax : matplotlib axes to draw on.
        edges : bin-edge specification, parsed by `kale.utils.parse_edges`.
        data : (2, N) array of data-points.
        weights : `None` or weights for each data-point.
        ratio : `None` or (2, M) array; the histogram of `data` is divided by
            the (unweighted) histogram of `ratio`.
        quantiles, sigmas : contour levels, parsed by `kale.plot._dfm_levels`.
        color, cmap : color and colormap, parsed by `kale.plot._parse_color_cmap`.
        smooth, upsample, pad : passed to `kale.plot._prep_hist`, which prepares
            the histogram used for contours and masking.
        outline : bool, add outline path-effects to contours and median lines.
        median : bool, draw cross-hair lines at the medians (not available with `ratio`).
        scatter : scatter setting, parsed by `kale.plot._scatter_limit`.
        contour : bool, draw contours.
        hist : bool, draw the 2D histogram.
        mask_dense : bool or `None`; cover scatter points in dense regions.
            `None` enables this when scatter is drawn with either `hist` or `contour`.
        mask_below : `True` uses the lowest contour level, otherwise passed
            as-is to `kale.plot.draw_hist2d`.
        mask_alpha : opacity of the dense-region mask.

        Returns
        -------
        handle : handle of the last of (scatter, histogram, contour) drawn, or `None`.

        Raises
        ------
        ValueError : if `data` or `ratio` are not shaped (2, N).
        NotImplementedError : if `median` is combined with `ratio`.

        """
        if np.ndim(data) != 2 or np.shape(data)[0] != 2:
            err = f"`data` (shape: {np.shape(data)}) must be 2D with shape (parameters, data-points)!"
            raise ValueError(err)

        if ratio is not None:
            if np.ndim(ratio) != 2 or np.shape(ratio)[0] != 2:
                err = f"`ratio` (shape: {np.shape(ratio)}) must be 2D with shape (parameters, data-points)!"
                raise ValueError(err)

        # Set default color or cmap as needed
        color, cmap = kale.plot._parse_color_cmap(ax=ax, color=color, cmap=cmap)

        # Use `scatter` as the limiting-number of scatter-points
        # To disable scatter, `scatter` will be set to `None`
        scatter = kale.plot._scatter_limit(scatter, "scatter")

        # Default: if either hist or contour is being plotted, mask over high-density scatter points
        if mask_dense is None:
            mask_dense = (scatter is not None) and (hist or contour)

        # Calculate histogram
        edges = kale.utils.parse_edges(data, edges=edges)
        hist_kwargs = dict(bins=edges, density=True)
        hh, *_ = np.histogram2d(*data, weights=weights, **hist_kwargs)
        if ratio is not None:
            hh_ratio, *_ = np.histogram2d(*ratio, **hist_kwargs)
            hh /= hh_ratio
            # empty bins in the denominator produce NaN/inf values
            hh = np.nan_to_num(hh)

        _, levels, quantiles = kale.plot._dfm_levels(hh, quantiles=quantiles, sigmas=sigmas)
        if mask_below is True:
            mask_below = levels.min()

        handle = None

        # ---- Draw Scatter Points
        if (scatter is not None):
            handle = kale.plot.draw_scatter(ax, *data, color=color, zorder=5, limit=scatter)

        # ---- Draw Median Lines (cross-hairs style)
        if median:
            # NOTE: `ratio` is a sequence, so it must be compared against `None` explicitly
            if ratio is not None:
                raise NotImplementedError("`median` is not impemented with `ratio`!")

            # Load path_effects
            out_pe = kale.plot._get_outline_effects() if outline else None
            for dd, func in zip(data, [ax.axvline, ax.axhline]):
                # Calculate value
                if weights is None:
                    med = np.median(dd)
                else:
                    med = kale.utils.quantiles(dd, percs=0.5, weights=weights)

                # Draw
                func(med, color=color, alpha=0.25, lw=1.0, zorder=40, path_effects=out_pe)

        cents, hh_prep = kale.plot._prep_hist(edges, hh, smooth, upsample, pad)

        # ---- Draw 2D Histogram
        if hist:
            _ee, _hh, handle = kale.plot.draw_hist2d(
                ax, edges, hh, mask_below=mask_below, cmap=cmap, zorder=10, shading='auto',
            )

        # ---- Draw Contours
        if contour:
            contour_cmap = cmap.reversed()
            # Narrow the range of contour colors relative to full `cmap`
            dd = 0.7 / 2
            nq = len(quantiles)
            if nq < 4:
                dd = nq*0.08
            contour_cmap = kale.plot._cut_colormap(contour_cmap, 0.5 - dd, 0.5 + dd)

            _ee, _hh, _handle = _contour2d(
                ax, cents, hh_prep, levels=levels, cmap=contour_cmap, zorder=20, outline=outline,
            )

            # Use the color of the last contour level.  Newer matplotlib versions no longer
            # provide `ContourSet.collections`; there the contour-set itself holds the colors.
            colls = getattr(_handle, 'collections', None)
            if colls is not None:
                col = colls[-1].get_edgecolor()
            else:
                col = _handle.get_edgecolor()[-1:]

            # for some reason the contour handle is not showing up on legends... create a dummy line
            # to get a better handle
            handle, = ax.plot([], [], color=col)

        # Mask dense scatter-points
        if mask_dense:
            # NOTE: levels need to be recalculated here!
            _, levels, quantiles = kale.plot._dfm_levels(hh_prep, quantiles=quantiles)
            span = [levels.min(), hh_prep.max()]
            mask_cmap = mpl.colors.ListedColormap('white')
            # Draw
            ax.contourf(*cents, hh_prep, span, cmap=mask_cmap, antialiased=True, zorder=9, alpha=mask_alpha)

        return handle

    def legend(self, handles, labels, index=None,
               loc=None, fancybox=False, borderaxespad=0, **kwargs):
        """Add a figure-level legend, by default anchored in the empty upper-right region.

        Arguments
        ---------
        handles, labels : legend handles and their labels.
        index : `None` or (row, col) index of the axes whose lower-left corner
            is used as the anchor.  If `None`, both the index and `loc` are
            chosen based on the number of parameters.
        loc : legend location relative to the anchor.
        fancybox, borderaxespad : passed to `fig.legend`.
        **kwargs : passed to `fig.legend`.  An anchor can be given explicitly
            as `bbox` or `bbox_to_anchor`, in which case `index` is ignored.

        Returns
        -------
        leg : the legend returned by `fig.legend`.

        """
        fig = self.fig

        # Set Bounding Box Location
        # ------------------------------------
        bbox = kwargs.pop('bbox', None)
        bbox = kwargs.pop('bbox_to_anchor', bbox)
        if bbox is None:
            if index is None:
                size = self.ndim
                if size in [2, 3, 4]:
                    index = (0, -1)
                    loc = 'lower left'
                elif size == 1:
                    index = (0, 0)
                    loc = 'upper right'
                elif size % 2 == 0:
                    index = size // 2
                    index = (size - index - 2, index + 1)
                    loc = 'lower left'
                else:
                    index = (size // 2) + 1
                    loc = 'lower left'
                    index = (size-index-1, index)

            bbox = self.axes[index].get_position()
            bbox = (bbox.x0, bbox.y0)

        # Always forward the anchor, so that a user-provided one is not silently discarded
        kwargs['bbox_to_anchor'] = bbox
        kwargs.setdefault('bbox_transform', fig.transFigure)

        # Set other defaults
        leg = fig.legend(handles, labels, fancybox=fancybox,
                         borderaxespad=borderaxespad, loc=loc, **kwargs)
        return leg

    def target(self, targets, upper_limits=None, lower_limits=None, lw=1.0, fill_alpha=0.1, **kwargs):
        """Mark target values (lines) and upper/lower limits (shaded spans) on all panels.

        An upper limit is shaded from the lower axes extremum up to the limit;
        a lower limit is shaded from the limit up to the upper axes extremum.
        The extrema are taken from the stored axes limits.
        NOTE(review): the stored limits are `None` until they are given or
        calculated by `plot` -- limits presumably require one of those first.

        Arguments
        ---------
        targets : (D,) list of target values, `None` to skip a parameter.
        upper_limits : `None` or (D,) list of upper limits, `None` to skip.
        lower_limits : `None` or (D,) list of lower limits, `None` to skip.
        lw : line-width of the target lines.
        fill_alpha : opacity of the shaded spans.
        **kwargs : style arguments for both lines and spans (defaults: red,
            `alpha=0.5`, `zorder=20`).

        Raises
        ------
        ValueError : if any argument does not have one entry per parameter.

        """
        size = self.ndim
        axes = self.axes
        last = size - 1
        # labs = self._labels
        extr = self._limits

        # ---- check / sanitize arguments
        if len(targets) != size:
            err = "`targets` (shape: {}) must be shaped ({},)!".format(np.shape(targets), size)
            raise ValueError(err)

        if lower_limits is None:
            lower_limits = [None] * size
        if len(lower_limits) != size:
            err = "`lower_limits` (shape: {}) must be shaped ({},)!".format(np.shape(lower_limits), size)
            raise ValueError(err)

        if upper_limits is None:
            upper_limits = [None] * size
        if len(upper_limits) != size:
            err = "`upper_limits` (shape: {}) must be shaped ({},)!".format(np.shape(upper_limits), size)
            raise ValueError(err)

        # ---- configure settings
        kwargs.setdefault('color', 'red')
        kwargs.setdefault('alpha', 0.50)
        kwargs.setdefault('zorder', 20)
        line_kw = dict(kwargs)
        line_kw['lw'] = lw
        span_kw = dict(kwargs)
        span_kw['alpha'] = fill_alpha

        def _draw(kk, draw_line, draw_span):
            """Draw the target line and limit spans of parameter `kk` with the given functions."""
            if targets[kk] is not None:
                draw_line(targets[kk], **line_kw)
            if upper_limits[kk] is not None:
                draw_span(extr[kk][0], upper_limits[kk], **span_kw)
            if lower_limits[kk] is not None:
                draw_span(lower_limits[kk], extr[kk][1], **span_kw)

        # ---- draw 1D targets and limits
        for jj, ax in enumerate(axes.diagonal()):
            # a rotated panel has the parameter value along the y-axis
            if self._rotate and (jj == last):
                _draw(jj, ax.axhline, ax.axhspan)
            else:
                _draw(jj, ax.axvline, ax.axvspan)

        # ---- draw 2D targets and limits
        for (ii, jj), ax in np.ndenumerate(axes):
            if jj >= ii:
                continue
            # row-parameter is along the y-axis, column-parameter along the x-axis
            _draw(ii, ax.axhline, ax.axhspan)
            _draw(jj, ax.axvline, ax.axvspan)

        return


def _contour2d(ax, edges, hist, levels, outline=True, **kwargs):
    """Draw contours of `hist` at the given `levels`, optionally with an outline effect.

    Arguments
    ---------
    ax : matplotlib axes to draw on.
    edges : pair of coordinate arrays, unpacked as the first arguments of `ax.contour`.
    hist : 2D array of values to contour.
    levels : contour levels.
    outline : if exactly `True`, outline path-effects are added to the contours.
    **kwargs : passed to `ax.contour`; `lw` and `ls` are accepted as aliases
        of `linewidths` and `linestyles`.

    Returns
    -------
    edges, hist : the input values.
    cont : the contour-set returned by `ax.contour`.

    """
    LW = 1.5

    alpha = kwargs.setdefault('alpha', 0.8)
    lw = kwargs.pop('linewidths', kwargs.pop('lw', LW))
    kwargs.setdefault('linestyles', kwargs.pop('ls', '-'))
    kwargs.setdefault('zorder', 10)

    # ---- Draw contours
    cont = ax.contour(*edges, hist, levels=levels, linewidths=lw, **kwargs)

    # ---- Add Outline path effect to contours
    if (outline is True):
        outline = kale.plot._get_outline_effects(2*lw, alpha=1 - np.sqrt(1 - alpha))
        # Newer matplotlib versions no longer provide `ContourSet.collections`;
        # there the contour-set is itself the artist to configure.
        artists = getattr(cont, 'collections', cont)
        plt.setp(artists, path_effects=outline)

    return edges, hist, cont
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Construct a new colormap from the `[minval, maxval]` sub-range of `cmap`.

    The input colormap is sampled at `n` evenly-spaced points between `minval`
    and `maxval`, and a new linear-segmented colormap is built from the samples.
    Adapted from: https://stackoverflow.com/a/18926541

    Arguments
    ---------
    cmap : str or colormap
        Colormap instance, or the name of one (looked up with `plt.get_cmap`).
    minval : float
        Lower end of the range of `cmap` to sample.
    maxval : float
        Upper end of the range of `cmap` to sample.
    n : int
        Number of samples taken from `cmap`.

    Returns
    -------
    new_cmap : `mpl.colors.LinearSegmentedColormap`
        Truncated colormap, named after the original colormap and the range.

    """
    # resolve colormap names into instances
    base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap

    name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=base.name, a=minval, b=maxval)
    samples = base(np.linspace(minval, maxval, n))
    return mpl.colors.LinearSegmentedColormap.from_list(name, samples)
# ================================================================================================= # ==== Below Needs Review / Cleaning ==== # ================================================================================================= ''' def plot_bin_pop(pop): mt, mr = utils.mtmr_from_m1m2(pop.mass) redz = cosmo.a_to_z(pop.scafa) data = [mt/MSOL, mr, pop.sepa/PC, 1+redz] data = [np.log10(dd) for dd in data] reflect = [None, [None, 0], None, [0, None]] labels = [r'M/M_\odot', 'q', r'a/\mathrm{{pc}}', '1+z'] labels = [r'${{\log_{{10}}}} \left({}\right)$'.format(ll) for ll in labels] if pop.eccen is not None: data.append(pop.eccen) reflect.append([0.0, 1.0]) labels.append('e') kde = kale.KDE(data, reflect=reflect) corner = kale.Corner(kde, labels=labels, figsize=[8, 8]) corner.plot_data(kde) return corner def plot_mbh_scaling_relations(pop, fname=None, color='r'): units = r"$[\log_{10}(M/M_\odot)]$" fig, ax = plt.subplots(figsize=[8, 5]) ax.set(xlabel=f'Stellar Mass {units}', ylabel=f'BH Mass {units}') # ==== Plot McConnell+Ma-2013 Data ==== handles = [] names = [] if fname is not None: hh = _draw_MM2013_data(ax, fname) handles.append(hh) names.append('McConnell+Ma') # ==== Plot MBH Merger Data ==== hh, nn = _draw_pop_masses(ax, pop, color) handles = handles + hh names = names + nn ax.legend(handles, names) return fig def _draw_MM2013_data(ax): data = observations.load_mcconnell_ma_2013() data = {kk: data[kk] if kk == 'name' else np.log10(data[kk]) for kk in data.keys()} key = 'mbulge' mass = data['mass'] yy = mass[:, 1] yerr = np.array([yy - mass[:, 0], mass[:, 2] - yy]) vals = data[key] if np.ndim(vals) == 1: xx = vals xerr = None elif vals.shape[1] == 2: xx = vals[:, 0] xerr = vals[:, 1] elif vals.shape[1] == 3: xx = vals[:, 1] xerr = np.array([xx-vals[:, 0], vals[:, 2]-xx]) else: raise ValueError() idx = (xx > 0.0) & (yy > 0.0) if xerr is not None: xerr = xerr[:, idx] ax.errorbar(xx[idx], yy[idx], xerr=xerr, yerr=yerr[:, idx], fmt='none', zorder=10) handle 
= ax.scatter(xx[idx], yy[idx], zorder=10) ax.set(ylabel='MBH Mass', xlabel=key) return handle def _draw_pop_masses(ax, pop, color='r', nplot=3e3): xx = pop.mbulge.flatten() / MSOL yy_list = [pop.mass] names = ['new'] if hasattr(pop, '_mass'): yy_list.append(pop._mass) names.append('old') colors = [color, '0.5'] handles = [] if xx.size > nplot: cut = np.random.choice(xx.size, int(nplot), replace=False) print("Plotting {:.1e}/{:.1e} data-points".format(nplot, xx.size)) else: cut = slice(None) for ii, yy in enumerate(yy_list): yy = yy.flatten() / MSOL data = np.log10([xx[cut], yy[cut]]) kale.plot.dist2d( data, ax=ax, color=colors[ii], hist=False, contour=True, median=True, mask_dense=True, ) hh, = plt.plot([], [], color=colors[ii]) handles.append(hh) return handles, names def plot_gwb(gwb, color=None, uniform=False, nreals=5): """Plot a GW background from the given `Grav_Waves` instance. Plots samples, confidence intervals, power-law, and adds twin-Hz axis (x2). Parameters ---------- gwb : `gravwaves.Grav_Waves` (subclass) instance Returns ------- fig : `mpl.figure.Figure` New matplotlib figure instance. 
""" fig, ax = figax( scale='log', xlabel=r'frequency $[\mathrm{yr}^{-1}]$', ylabel=r'characteristic strain $[\mathrm{h}_c]$' ) if uniform: color = ax._get_lines.get_next_color() _draw_gwb_sample(ax, gwb, color=color, num=nreals) _draw_gwb_conf(ax, gwb, color=color) _draw_plaw(ax, gwb.freqs*YR, f0=1, color='0.5', lw=2.0, ls='--') _twin_hz(ax, nano=True, fs=12) return fig def _draw_gwb_sample(ax, gwb, num=10, back=True, fore=True, color=None): back_flag = back fore_flag = fore back = gwb.back fore = gwb.fore freqs = gwb.freqs * YR pl = dict(alpha=0.5, color=color, lw=0.8) plsel = dict(alpha=0.85, color=color, lw=1.6) sc = dict(alpha=0.25, s=20, fc=color, lw=0.0, ec='none') scsel = dict(alpha=0.50, s=40, ec='k', fc=color, lw=1.0) cut = np.random.choice(back.shape[1], num, replace=False) sel = cut[0] cut = cut[1:] color_gen = None color_sel = None if back_flag: hands_gen = ax.plot(freqs, back[:, cut], **pl) hands_sel, = ax.plot(freqs, back[:, sel], **plsel) color_gen = [hh.get_color() for hh in hands_gen] color_sel = hands_sel.get_color() if color is None: sc['fc'] = color_gen scsel['fc'] = color_sel if fore_flag: yy = fore[:, cut] xx = freqs[:, np.newaxis] * np.ones_like(yy) dx = np.diff(freqs) dx = np.concatenate([[dx[0]], dx])[:, np.newaxis] dx *= 0.2 xx += np.random.normal(0, dx, np.shape(xx)) # xx += np.random.uniform(-dx, dx, np.shape(xx)) xx = np.clip(xx, freqs[0]*0.75, None) ax.scatter(xx, yy, **sc) yy = fore[:, sel] xx = freqs ax.scatter(xx, yy, **scsel) return def _draw_gwb_conf(ax, gwb, **kwargs): conf = [0.25, 0.50, 0.75] freqs = gwb.freqs * YR back = gwb.back kwargs.setdefault('alpha', 0.5) kwargs.setdefault('lw', 0.5) conf = np.percentile(back, 100*np.array(conf), axis=-1) ax.fill_between(freqs, conf[0], conf[-1], **kwargs) kwargs['alpha'] = 1.0 - 0.5*(1.0 - kwargs['alpha']) ax.plot(freqs, conf[1], **kwargs) return '''