# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
# SPDX-License-Identifier: GPL-3.0-or-later
# nngt/plot/custom_plt.py
""" Matplotlib customization """
import itertools
import logging
from pkg_resources import parse_version
import matplotlib as mpl
from matplotlib.colors import Colormap
from matplotlib.markers import MarkerStyle as MS
import nngt
from nngt.lib.logger import _log_message
logger = logging.getLogger(name=__name__)

# ---------------- #
# Customize PyPlot #
# ---------------- #

# Switched to True by the seaborn branch once `seaborn` has been imported.
with_seaborn = False
def get_cmap(colormap, n=None):
    '''
    Get a colormap.

    Parameters
    ----------
    colormap : str or :class:`~matplotlib.colors.Colormap`
        Colormap to return, either directly or through its registered name.
    n : int, optional
        Take `n` samples from the colormap, i.e. return a new colormap
        resampled to `n` entries.

    Returns
    -------
    cmap : :class:`~matplotlib.colors.Colormap`
    '''
    if not isinstance(colormap, Colormap):
        registry = getattr(mpl, "colormaps", None)

        if registry is not None:
            colormap = registry[colormap]
        else:
            # matplotlib < 3.5 has no ``mpl.colormaps`` registry
            from matplotlib import cm
            colormap = cm.get_cmap(colormap)

    if n is None:
        return colormap

    # ``Colormap.resampled`` is public since matplotlib 3.6, older versions
    # only provide the private ``_resample``; detect the feature directly
    # instead of parsing version strings (``pkg_resources`` is deprecated).
    # @TODO require matplotlib >= 3.6 and drop the fallback
    if hasattr(colormap, "resampled"):
        return colormap.resampled(n)

    return colormap._resample(n)
def palette_continuous(numbers=None):
    '''
    Continuous palette set by ``nngt._config["palette_continuous"]``.

    Parameters
    ----------
    numbers : float or array-like, optional
        Value(s) at which the colormap is evaluated.

    Returns
    -------
    The colormap itself if `numbers` is None, otherwise the result of
    calling the colormap on `numbers`.
    '''
    pal = get_cmap(nngt._config["palette_continuous"])

    return pal if numbers is None else pal(numbers)
def palette_discrete(numbers=None):
    '''
    Discrete palette set by ``nngt._config["palette_discrete"]``.

    Parameters
    ----------
    numbers : float or array-like, optional
        Value(s) at which the colormap is evaluated.

    Returns
    -------
    The colormap itself if `numbers` is None, otherwise the result of
    calling the colormap on `numbers`.
    '''
    pal = get_cmap(nngt._config["palette_discrete"])

    return pal if numbers is None else pal(numbers)
# Filled marker symbols usable for plotting (the point marker '.' is excluded)
markers = list(filter(lambda symbol: symbol != '.', MS.filled_markers))
if nngt._config["color_lib"] == "seaborn":
    try:
        import seaborn as sns

        with_seaborn = True

        sns.set_style("whitegrid")

        def sns_palette(numbers=None):
            '''
            Seaborn-based replacement for :func:`palette_continuous`.

            Parameters
            ----------
            numbers : float or sized container, optional
                If None, return the continuous colormap (same as the
                matplotlib-based palette). If a float, return the colour at
                that fraction of a 100-colour seaborn palette (values are
                clamped to the palette). Otherwise, return ``len(numbers)``
                colours sampled from the palette.
            '''
            # same config key as the function this one replaces
            name = nngt._config["palette_continuous"]

            if numbers is None:
                return get_cmap(name)

            if isinstance(numbers, float):
                pal = sns.color_palette(name, 100)
                # clamp so that 1.0 (and out-of-range values) give a valid
                # index instead of raising IndexError / wrapping around
                idx = min(max(int(numbers*100), 0), len(pal) - 1)
                return pal[idx]

            return sns.color_palette(name, len(numbers))

        palette_continuous = sns_palette
    except ImportError as e:
        _log_message(logger, "WARNING",
                     "`seaborn` requested but could not set it: {}.".format(e))
if not with_seaborn:
    # Apply NNGT's default matplotlib look when seaborn is not in use.
    # This is best-effort: any failure is logged as a warning, not raised.
    try:
        mpl.rcParams['font.size'] = 12
        mpl.rcParams['font.family'] = 'serif'

        if nngt._config['use_tex']:
            mpl.rc('text', usetex=True)

        # the text sizes below are derived from the base font size set above
        base_size = mpl.rcParams['font.size']

        mpl.rcParams.update({
            # text sizes
            'axes.labelsize': base_size,
            'axes.titlesize': 1.2*base_size,
            'legend.fontsize': base_size,
            'xtick.labelsize': base_size,
            'ytick.labelsize': base_size,
            # saved figures
            'savefig.dpi': 300,
            'savefig.format': 'pdf',
            # ticks
            'xtick.major.size': 3,
            'xtick.minor.size': 3,
            'xtick.major.width': 1,
            'xtick.minor.width': 1,
            'ytick.major.size': 3,
            'ytick.minor.size': 3,
            'ytick.major.width': 1,
            'ytick.minor.width': 1,
            # legend, axes, grid and paths
            'legend.frameon': False,
            'legend.numpoints': 1,
            'axes.linewidth': 1,
            'axes.grid': True,
            'grid.linestyle': ':',
            'path.simplify': True,
        })
    except Exception as err:
        _log_message(logger, "WARNING",
                     "Error configuring `matplotlib`: {}.".format(err))
def format_exponent(ax, axis='y', pos=(1., 0.), valign="top", halign="right"):
    '''
    Use scientific tick labels on one axis and replace matplotlib's automatic
    offset text by an ``x10^{n}`` label placed at a chosen position.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to modify.
    axis : str, optional (default: 'y')
        'y' selects the y-axis; any other value selects the x-axis.
    pos : tuple of 2 floats, optional (default: (1., 0.))
        Position of the exponent label, in axes coordinates.
    valign : str, optional (default: "top")
        Vertical alignment of the exponent label.
    halign : str, optional (default: "right")
        Horizontal alignment of the exponent label.

    Returns
    -------
    ax : the modified axes.

    Note
    ----
    The figure containing `ax` is laid out with ``tight_layout`` as a side
    effect, because the offset text is only computed during layout/drawing.
    '''
    # Change the ticklabel format to scientific format
    ax.ticklabel_format(axis=axis, style='sci', scilimits=(-3, 2))

    ax_axis = ax.yaxis if axis == 'y' else ax.xaxis

    # Lay out the figure that owns `ax` so that the offset text is updated;
    # ``plt.tight_layout()`` would act on the *current* figure instead,
    # which is not necessarily the one containing `ax`.
    ax.figure.tight_layout()

    offset_artist = ax_axis.get_offset_text()
    offset = offset_artist.get_text()

    if offset:
        # matplotlib renders negative exponents with U+2212 (minus sign)
        minus_sign = u'\u2212'
        expo = float(offset.replace(minus_sign, '-').split('e')[-1])

        offset_text = r'x$\mathregular{10^{%d}}$' % expo

        # hide the automatic offset text and draw our own label instead
        offset_artist.set_visible(False)

        ax.text(pos[0], pos[1], offset_text, transform=ax.transAxes,
                horizontalalignment=halign,
                verticalalignment=valign)

    return ax