# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
# SPDX-License-Identifier: GPL-3.0-or-later
# nngt/simulation/nest_plot.py
""" Utility functions to plot NEST simulated activity """
import itertools
import logging
from matplotlib.colors import ColorConverter
import numpy as np
import nest
import nngt
from nngt.analysis import total_firing_rate
from nngt.lib import InvalidArgument, nonstring_container, is_integer
from nngt.lib.sorting import _sort_groups, _sort_neurons
from nngt.lib.logger import _log_message
from nngt.plot import palette_discrete, markers
from nngt.plot.plt_properties import _set_new_plot, _set_ax_lims
from .nest_utils import nest_version, spike_rec, _get_nest_gids
# module-level logger; the warnings in this module are emitted through
# `_log_message(logger, "WARNING", ...)`
logger = logging.getLogger(__name__)
# --------------------- #
# Plotting the activity #
# --------------------- #
def plot_activity(gid_recorder=None, record=None, network=None, gids=None,
                  axis=None, show=False, limits=None, histogram=False,
                  title=None, fignum=None, label=None, sort=None,
                  average=False, normalize=1., decimate=None, transparent=True,
                  kernel_center=0., kernel_std=None, resolution=None,
                  cut_gaussian=5., **kwargs):
    '''
    Plot the monitored activity.

    .. versionchanged:: 1.2
        Switched `hist` to `histogram` and default value to False.

    .. versionchanged:: 1.0.1
        Added `axis` parameter, restored missing `fignum` parameter.

    Parameters
    ----------
    gid_recorder : tuple or list of tuples, optional (default: None)
        The gids of the recording devices. If None, then all existing
        spike_recs are used.
    record : tuple or list, optional (default: None)
        List of the monitored variables for each device. If `gid_recorder` is
        None, record can also be None and only spikes are considered.
    network : :class:`~nngt.Network` or subclass, optional (default: None)
        Network which activity will be monitored.
    gids : tuple, optional (default: None)
        NEST gids of the neurons which should be monitored.
    axis : matplotlib axis object, optional (default: new one)
        Axis that should be used to plot the activity. This takes precedence
        over `fignum`.
    show : bool, optional (default: False)
        Whether to show the plot right away or to wait for the next plt.show().
    histogram : bool, optional (default: False)
        Whether to display the histogram when plotting spikes rasters.
    limits : tuple, optional (default: None)
        Time limits of the plot (if not specified, times of first and last
        spike for raster plots).
    title : str, optional (default: None)
        Title of the plot.
    fignum : int, or dict, optional (default: None)
        Plot the activity on an existing figure (from ``figure.number``).
        A dict maps recorder models to figure numbers; it is not modified.
        This parameter is ignored if `axis` is provided.
    label : str or list, optional (default: None)
        Add labels to the plot (one per recorder).
    sort : str or list, optional (default: None)
        Sort neurons using a topological property ("in-degree", "out-degree",
        "total-degree" or "betweenness"), an activity-related property
        ("firing_rate" or neuronal property) or a user-defined list of sorted
        neuron ids. Sorting is performed by increasing value of the `sort`
        property from bottom to top inside each group.
    average : bool, optional (default: False)
        Currently unused.
    normalize : float, list or NeuralPop, optional (default: 1.)
        Normalize the recorded results by a given float. If a list is provided,
        there should be one entry per voltmeter or multimeter in the recorders.
        If the recording was done through `monitor_groups`, the population can
        be passed to normalize the data by the number of nodes in each group.
        If None, the data is not normalized.
    decimate : int or list of ints, optional (default: None)
        Represent only a fraction of the spiking neurons; only one neuron in
        `decimate` will be represented (e.g. setting `decimate` to 5 will lead
        to only 20% of the neurons being represented). If a list is provided,
        it must have one entry per plot (the maximum between the number of
        NeuralGroups in the population and the number of recorders).
    kernel_center : float, optional (default: 0.)
        Temporal shift of the Gaussian kernel, in ms (for the histogram).
    kernel_std : float, optional (default: 0.5% of the last event time)
        Characteristic width of the Gaussian kernel (standard deviation) in ms
        (for the histogram).
    resolution : float or array, optional (default: `0.5*kernel_std`)
        The resolution at which the firing rate values will be computed.
        Choosing a value smaller than `kernel_std` is strongly advised.
        If resolution is an array, it will be considered as the times where
        the firing rate should be computed (for the histogram).
    cut_gaussian : float, optional (default: 5.)
        Range over which the Gaussian will be computed (for the histogram).
        By default, we consider the 5-sigma range. Decreasing this value will
        increase speed at the cost of lower fidelity; increasing it will
        increase the fidelity at the cost of speed.
    **kwargs : dict
        "color" and "alpha" values can be overridden here.

    Warning
    -------
    Sorting with "firing_rate" only works if NEST gids form a continuous
    integer range.

    Returns
    -------
    lines : dict
        Lists of :class:`matplotlib.lines.Line2D` containing the data that was
        plotted, keyed by recorder model.
    '''
    import matplotlib.pyplot as plt

    recorders = _get_nest_gids([])

    # `axes` stores the raster ('spike_rec' key) and histogram axes
    lines, axes, labels = {}, {}, {}

    # normalize recorders and recordables
    if gid_recorder is not None:
        assert record is not None, "`record` must also be provided."

        if len(record) != len(gid_recorder):
            raise InvalidArgument('`record` must either be the same for all '
                                  'recorders, or contain one entry per '
                                  'recorder in `gid_recorder`')

        if nest_version == 3:
            recorders = _get_nest_gids(gid_recorder)
        else:
            # NEST 2 recorders are handled as 1-tuples of gids
            for rec in gid_recorder:
                if isinstance(gid_recorder[0], tuple):
                    recorders.append(rec)
                else:
                    recorders.append((rec,))
    else:
        prop = {'model': spike_rec}

        if nest_version == 3:
            recorders = nest.GetNodes(properties=prop)
        else:
            recorders = [
                (gid,) for gid in nest.GetNodes((0,), properties=prop)[0]
            ]

        record = tuple("spikes" for _ in range(len(recorders)))

    # get gids and groups
    if gids is None and network is not None:
        gids = network.nest_gids

    if gids is None:
        gids = []

        for rec in recorders:
            gids.extend(nest.GetStatus(rec)[0]["events"]["senders"])

        gids = np.unique(gids)

    num_group = 1 if network is None else len(network.population)
    num_lines = max(num_group, len(recorders))

    # sorting: by default, map each gid to `gid - min(gids) + 1`
    sorted_neurons = np.array([])

    if len(gids):
        sorted_neurons = np.arange(
            np.max(gids) + 1).astype(int) - np.min(gids) + 1

    attr = None

    if sort is not None:
        assert network is not None, "`network` is required for sorting."

        if nonstring_container(sort):
            attr = sort
            sorted_neurons = _sort_neurons(attr, gids, network)
            sort = "user defined sort"
        else:
            data = None

            if sort.lower() in ("firing_rate", "b2"):  # get senders
                data = [[], []]

                for rec in recorders:
                    info = nest.GetStatus(rec)[0]

                    if str(info["model"]) == spike_rec:
                        data[0].extend(info["events"]["senders"])
                        data[1].extend(info["events"]["times"])

                data = np.array(data).T

            sorted_neurons, attr = _sort_neurons(
                sort, gids, network, data=data, return_attr=True)
    elif network is not None and network.is_spatial():
        sorted_neurons, attr = _sort_neurons(
            "space", gids, network, data=None, return_attr=True)

    # spikes plotting
    colors = palette_discrete(np.linspace(0, 1, num_lines))
    num_raster, num_detec, num_meter = 0, 0, 0

    # copy so that a user-provided dict is not modified
    fignums = dict(fignum) if isinstance(fignum, dict) else {}

    # figure number for models absent from `fignums` (never the dict itself)
    default_fnum = None if isinstance(fignum, dict) else fignum

    if decimate is None:
        decim = [None for _ in range(num_lines)]
    elif is_integer(decimate):
        decim = [decimate for _ in range(num_lines)]
    elif nonstring_container(decimate):
        assert len(decimate) == num_lines, "`decimate` should have one " +\
                                           "entry per plot."
        decim = decimate
    else:
        raise AttributeError(
            "`decimate` must be either an int or a list of `int`.")

    # set labels
    if label is None:
        lst_labels = [None for _ in range(len(recorders))]
    else:
        lst_labels = [label] if isinstance(label, str) else list(label)

        # compare the number of labels, not the length of a string label
        if len(lst_labels) != len(recorders):
            _log_message(logger, "WARNING",
                         'Incorrect length for `label`: expecting {} but got '
                         '{}.\nIgnoring.'.format(len(recorders),
                                                 len(lst_labels)))

            lst_labels = [None for _ in range(len(recorders))]

    datasets = []
    max_time = 0.

    for rec in recorders:
        info = nest.GetStatus(rec)[0]

        if len(info["events"]["times"]):
            max_time = max(max_time, np.max(info["events"]["times"]))

        datasets.append(info)

    if kernel_std is None:
        kernel_std = max_time*0.005

    if resolution is None:
        resolution = 0.5*kernel_std

    def _norm_factor(idx):
        ''' Normalization factor for the `idx`-th meter (None: no scaling). '''
        if isinstance(normalize, nngt.NeuralPop):
            return normalize[idx].size

        if nonstring_container(normalize):
            return normalize[idx]

        return normalize

    # plot
    for info, var, lbl in zip(datasets, record, lst_labels):
        model = info["model"]
        fnum = fignums.get(model, default_fnum)

        labels.setdefault(model, [])
        lines.setdefault(model, [])

        if str(model) == spike_rec:
            # reuse the previous raster axis (if any) without overwriting
            # the user-provided `axis`, which is also used for the meters
            rec_axis = axes.get(spike_rec, axis)

            c = colors[num_raster]
            times, senders = info["events"]["times"], info["events"]["senders"]
            sorted_ids = sorted_neurons[senders]

            l = raster_plot(times, sorted_ids, color=c, show=False,
                            limits=limits, sort=sort, fignum=fnum,
                            axis=rec_axis, decimate=decim[num_raster],
                            sort_attribute=attr, network=network,
                            histogram=histogram, transparent=transparent,
                            hist_ax=axes.get('histogram', None),
                            kernel_center=kernel_center,
                            kernel_std=kernel_std, resolution=resolution,
                            cut_gaussian=cut_gaussian)

            num_raster += 1

            if l:
                fignums[spike_rec] = l[0].figure.number
                axes[spike_rec] = l[0].axes

                labels.setdefault(spike_rec, []).append(lbl)
                lines.setdefault(spike_rec, []).extend(l)

                if histogram:
                    # second line returned is the histogram line
                    axes['histogram'] = l[1].axes
        elif "detector" in str(model):
            c = colors[num_detec]
            times, senders = info["events"]["times"], info["events"]["senders"]
            sorted_ids = sorted_neurons[senders]

            l = raster_plot(times, sorted_ids, fignum=fnum, color=c,
                            axis=axis, show=False, histogram=histogram,
                            limits=limits, kernel_center=kernel_center,
                            kernel_std=kernel_std, resolution=resolution,
                            cut_gaussian=cut_gaussian)

            if l:
                num_detec += 1
                fignums[model] = l[0].figure.number
                labels[model].append(lbl)
                lines[model].extend(l)

                if histogram:
                    axes['histogram'] = l[1].axes
        else:
            # other devices (meters): plot the recorded variable(s) vs time
            da_time = info["events"]["times"]
            norm = _norm_factor(num_meter)

            # prepare axis setup
            if axis is None:
                fig = plt.figure(fnum)
                fignums[model] = fig.number
            else:
                fig = axis.get_figure()

            lines_tmp, labels_tmp = [], []

            if nonstring_container(var):
                m_colors = palette_discrete(np.linspace(0, 1, len(var)))

                # local name: must not rebind `axes` (raster/histogram dict)
                meter_axes = fig.axes

                if axis is not None:
                    # multiple y axes on a single subplot, adapted from
                    # https://matplotlib.org/examples/pylab_examples/
                    # multiple_yaxis_with_spines.html
                    meter_axes = [axis]
                    axis.name = var[0]

                    if len(var) > 1:
                        meter_axes.append(axis.twinx())
                        meter_axes[-1].name = var[1]

                    if len(var) > 2:
                        fig.subplots_adjust(right=0.75)

                        for i, name in enumerate(var[2:]):
                            new_ax = axis.twinx()
                            new_ax.spines["right"].set_position(
                                ("axes", 1.2*(i+1)))
                            meter_axes.append(new_ax)
                            _make_patch_spines_invisible(new_ax)
                            new_ax.spines["right"].set_visible(True)
                            meter_axes[-1].name = name

                if not meter_axes:
                    meter_axes = _set_new_plot(fig.number, names=var)[1]

                labels_tmp = [lbl for _ in range(len(var))]
                alpha = kwargs.get('alpha', 1)

                for subvar, c in zip(var, m_colors):
                    c = kwargs.get('color', c)

                    for ax in meter_axes:
                        if ax.name == subvar:
                            da_subvar = info["events"][subvar]

                            if norm is not None:
                                # not in-place: also works for integer data
                                da_subvar = da_subvar / norm

                            lines_tmp.extend(
                                ax.plot(da_time, da_subvar, color=c,
                                        alpha=alpha))

                            ax.set_ylabel(subvar)
                            ax.set_xlabel("time")

                            if limits is not None:
                                ax.set_xlim(limits[0], limits[1])
            else:
                ax = axis

                if ax is None:
                    num_axes = len(fig.axes)
                    ax = fig.add_subplot(num_axes + 1, 1, num_axes + 1)

                da_var = info["events"][var]

                if norm is not None:
                    da_var = da_var / norm

                lines_tmp.extend(ax.plot(da_time, da_var,
                                         color=kwargs.get('color', None),
                                         alpha=kwargs.get('alpha', 1)))

                labels_tmp.append(lbl)

                ax.set_ylabel(var)
                ax.set_xlabel("time")

                if limits is not None:
                    ax.set_xlim(limits[0], limits[1])

            labels[model].extend(labels_tmp)
            lines[model].extend(lines_tmp)

            num_meter += 1

    # adjust the limits of the raster plot to all the rasters it contains
    if spike_rec in axes:
        ax = axes[spike_rec]

        if limits is not None:
            ax.set_xlim(limits[0], limits[1])
        else:
            t_min, t_max, idx_min, idx_max = np.inf, -np.inf, np.inf, -np.inf

            for line in ax.lines:
                xdata = np.asarray(line.get_xdata())
                ydata = np.asarray(line.get_ydata())

                if not len(xdata):
                    continue

                t_max = max(np.max(xdata), t_max)
                t_min = min(np.min(xdata), t_min)
                idx_min = min(np.min(ydata), idx_min)
                idx_max = max(np.max(ydata), idx_max)

            dt = t_max - t_min
            didx = idx_max - idx_min
            pc = 0.02

            if not np.any(np.isinf((t_max, t_min))):
                ax.set_xlim([t_min - pc*dt, t_max + pc*dt])

            if not np.any(np.isinf((idx_min, idx_max))):
                ax.set_ylim([idx_min - pc*didx, idx_max + pc*didx])

    for recorder, num in fignums.items():
        fig = plt.figure(num)

        if title is not None:
            fig.suptitle(title)

        rec_labels = labels.get(recorder, [])

        if label is not None and any(lb is not None for lb in rec_labels):
            fig.legend(lines.get(recorder, []), rec_labels)

    if show:
        plt.show()

    return lines
def raster_plot(times, senders, limits=None, title="Spike raster",
                histogram=False, num_bins=1000, color="b", decimate=None,
                axis=None, fignum=None, label=None, show=True, sort=None,
                sort_attribute=None, network=None, transparent=True,
                kernel_center=0., kernel_std=30., resolution=None,
                cut_gaussian=5., **kwargs):
    """
    Plotting routine that constructs a raster plot along with
    an optional histogram.

    .. versionchanged:: 1.2
        Switched `hist` to `histogram`.

    .. versionchanged:: 1.0.1
        Added `axis` parameter.

    Parameters
    ----------
    times : list or :class:`numpy.ndarray`
        Spike times.
    senders : list or :class:`numpy.ndarray`
        Index for the spiking neuron for each time in `times`.
    limits : tuple, optional (default: None)
        Time limits of the plot (if not specified, times of first and last
        spike). Spikes outside these limits are discarded.
    title : string, optional (default: 'Spike raster')
        Title of the raster plot.
    histogram : bool, optional (default: False)
        Whether to plot the raster's histogram (smoothed total firing rate,
        stacked on top of previous histograms on the same axis).
    num_bins : int, optional (default: 1000)
        Currently unused (kept for backward compatibility).
    color : string or float, optional (default: 'b')
        Color of the plot lines and markers.
    decimate : int, optional (default: None)
        Represent only a fraction of the spiking neurons; only one neuron in
        `decimate` will be represented (e.g. setting `decimate` to 10 will lead
        to only 10% of the neurons being represented).
    axis : matplotlib axis object, optional (default: new one)
        Axis that should be used to plot the activity.
    fignum : int, optional (default: None)
        Id of another raster plot to which the new data should be added.
    label : str, optional (default: None)
        Label the current data.
    show : bool, optional (default: True)
        Whether to show the plot right away or to wait for the next plt.show().
    kernel_center : float, optional (default: 0.)
        Temporal shift of the Gaussian kernel, in ms.
    kernel_std : float, optional (default: 30.)
        Characteristic width of the Gaussian kernel (standard deviation) in ms.
    resolution : float or array, optional (default: None)
        The resolution at which the firing rate values will be computed
        (None uses the default of :func:`~nngt.analysis.total_firing_rate`).
        Choosing a value smaller than `kernel_std` is strongly advised.
        If resolution is an array, it will be considered as the times where
        the firing rate should be computed.
    cut_gaussian : float, optional (default: 5.)
        Range over which the Gaussian will be computed (for the histogram).
        By default, we consider the 5-sigma range. Decreasing this value will
        increase speed at the cost of lower fidelity; increasing it will
        increase the fidelity at the cost of speed.
    **kwargs : dict
        Passed to matplotlib's ``plot`` for the raster, except `hist_ax`,
        an existing histogram axis (only used together with `axis`).

    Returns
    -------
    lines : list of :class:`matplotlib.lines.Line2D`
        Lines containing the data that was plotted (empty if there was no
        spike to plot).
    """
    import matplotlib.pyplot as plt

    lines = []

    hist_ax = kwargs.get("hist_ax", None)
    mpl_kwargs = {k: v for k, v in kwargs.items() if k != 'hist_ax'}

    # forward the user label to matplotlib (condition was inverted before)
    if label is not None:
        mpl_kwargs['label'] = label

    # lists are accepted as documented; masking below needs arrays
    times = np.asarray(times)
    senders = np.asarray(senders)

    # decimate if necessary
    if decimate is not None:
        keep = np.mod(senders, decimate) == 0
        times, senders = times[keep], senders[keep]

    # restrict to the requested time window before plotting anything
    if limits is not None:
        start, stop = limits
        keep = (times >= start) & (times <= stop)
        times, senders = times[keep], senders[keep]

    if not len(times):
        _log_message(logger, "WARNING",
                     "No activity was detected during the simulation.")
        return lines

    if axis is not None:
        fig = axis.get_figure()
    else:
        fig = plt.figure(fignum)

    if transparent:
        fig.patch.set_visible(False)

    ylabel = "Neuron ID"
    xlabel = "Time (ms)"

    # do not assume that `times` is sorted
    t_first, t_last = np.min(times), np.max(times)
    delta_t = 0.01*(t_last - t_first)

    if histogram:
        if hist_ax is None or axis is None:
            # add the raster and histogram rows below the existing axes
            num_axes = len(fig.axes)

            for i, old_ax in enumerate(fig.axes):
                _change_geometry(old_ax, num_axes + 2, i + 1)

            ax1 = fig.add_subplot(num_axes + 2, 1, num_axes + 1)
            ax2 = fig.add_subplot(num_axes + 2, 1, num_axes + 2,
                                  sharex=ax1)
        else:
            ax1, ax2 = axis, hist_ax

        lines.extend(ax1.plot(
            times, senders, c=color, marker="o", linestyle='None',
            mec="k", mew=0.5, ms=4, **mpl_kwargs))

        ax1_lines = ax1.lines

        if len(ax1_lines) > 1:
            t_max = max(np.max(ax1_lines[0].get_xdata()), t_last)
            ax1.set_xlim([-delta_t, t_max + delta_t])

        ax1.set_ylabel(ylabel)

        if limits is not None:
            ax1.set_xlim(*limits)

        fr, fr_times = total_firing_rate(
            data=np.array([senders, times]).T, kernel_center=kernel_center,
            kernel_std=kernel_std, resolution=resolution,
            cut_gaussian=cut_gaussian)

        # stack the new rate on top of the previous histogram (if any): the
        # previous top curve is resampled on the new times, zero outside
        bottom = np.zeros(len(fr))
        hist_lines = ax2.get_lines()

        if hist_lines:
            old_times, old_top = hist_lines[-1].get_data()
            bottom = np.interp(fr_times, old_times, old_top,
                               left=0., right=0.)

        ax2.fill_between(fr_times, fr + bottom, bottom, color=color)
        lines.extend(ax2.plot(fr_times, fr + bottom, ls="", marker=""))

        ax2.set_ylabel("Rate (Hz)")
        ax2.set_xlabel(xlabel)
        ax2.set_xlim(ax1.get_xlim())

        _second_axis(sort, sort_attribute, ax1)
    else:
        if axis is not None:
            ax = axis
        else:
            num_axes = len(fig.axes)

            for i, old_ax in enumerate(fig.axes):
                _change_geometry(old_ax, num_axes + 1, i + 1)

            ax = fig.add_subplot(num_axes + 1, 1, num_axes + 1)

        if network is not None:
            # one color/marker per group of the population
            # NOTE(review): `senders` may be sorted ids rather than NEST gids
            # (see plot_activity) -- confirm the group matching below
            pop = network.population
            colors = palette_discrete(np.linspace(0, 1, len(pop)))
            mm = itertools.cycle(markers)

            for m, (k, v), c in zip(mm, pop.items(), colors):
                keep = np.where(
                    np.isin(senders, network.nest_gids[v.ids]))[0]

                if len(keep):
                    if label is None:
                        mpl_kwargs['label'] = k

                    lines.extend(ax.plot(
                        times[keep], senders[keep], c=c, marker=m,
                        ls='None', mec='k', mew=0.5, ms=4, **mpl_kwargs))
        else:
            lines.extend(ax.plot(
                times, senders, c=color, marker="o", linestyle='None',
                mec="k", mew=0.5, ms=4, **mpl_kwargs))

        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

        if limits is not None:
            ax.set_xlim(limits)
        else:
            _set_ax_lims(ax, np.max(times), np.min(times), np.max(senders),
                         np.min(senders))

        if label is not None:
            ax.legend(bbox_to_anchor=(1.1, 1.2))

        _second_axis(sort, sort_attribute, ax)

    fig.suptitle(title)

    if show:
        plt.show()

    return lines


def _change_geometry(ax, num_rows, position):
    '''
    Move `ax` to row `position` (1-based) of a `num_rows` x 1 subplot grid.

    Stand-in for ``Axes.change_geometry(num_rows, 1, position)``, which was
    deprecated in Matplotlib 3.4 and removed in 3.6.
    '''
    from matplotlib.gridspec import GridSpec

    fig = ax.get_figure()
    spec = GridSpec(num_rows, 1, figure=fig)[position - 1]

    ax.set_subplotspec(spec)

    # some Matplotlib versions may not update the position in
    # `set_subplotspec`, so set it explicitly
    ax.set_position(spec.get_position(fig))
# ----- #
# Tools #
# ----- #
def _fill_between_steps(x, y1, y2=0, h_align='mid'):
'''
Fills a hole in matplotlib: fill_between for step plots.
Parameters :
------------
x : array-like
Array/vector of index values. These are assumed to be equally-spaced.
If not, the result will probably look weird...
y1 : array-like
Array/vector of values to be filled under.
y2 : array-Like
Array/vector or bottom values for filled area. Default is 0.
'''
# First, duplicate the x values
xx = np.repeat(x,2)
# Now: the average x binwidth
xstep = np.repeat((x[1:] - x[:-1]), 2)
xstep = np.concatenate(([xstep[0]], xstep, [xstep[-1]]))
# Now: add one step at end of row.
#~ xx = np.append(xx, xx.max() + xstep[-1])
# Make it possible to change step alignment.
if h_align == 'mid':
xx -= xstep / 2.
elif h_align == 'right':
xx -= xstep
# Also, duplicate each y coordinate in both arrays
y1 = np.repeat(y1,2)#[:-1]
if type(y2) == np.ndarray:
y2 = np.repeat(y2,2)#[:-1]
return xx, y1, y2
def _moving_average (values, window):
weights = np.repeat(1.0, window)/window
sma = np.convolve(values, weights, 'same')
return sma
def _second_axis(sort, sort_attribute, ax):
    '''
    Add a right-hand y axis, labelled `sort`, to the figure of `ax`.

    Its ticks match the y ticks of `ax`; a tick at position ``t >= 0`` is
    labelled with the ``int(t)``-th smallest value of `sort_attribute`
    (clipped to the last one), negative ticks get an empty label.
    Nothing is done if `sort` is None or if an axis of the figure is already
    labelled `sort`.
    '''
    import matplotlib.pyplot as plt

    if sort is None:
        return

    figure = ax.get_figure()

    # a sorting axis was already added to this figure
    if any(other.get_ylabel() == sort for other in figure.axes):
        return

    order = np.argsort(sort_attribute)

    right_ax = ax.twinx()
    right_ax.grid(False)
    right_ax.set_ylabel(sort)

    # presumably forces a draw so that the ticks of `ax` are up to date
    plt.draw()

    ticks = ax.get_yticks()
    right_ax.set_yticks(ticks)
    right_ax.set_ylim(ax.get_ylim())

    last = len(sort_attribute) - 1

    tick_labels = [
        _sci_format(sort_attribute[order[min(int(tick), last)]])
        if tick >= 0 else ''
        for tick in ticks
    ]

    right_ax.set_yticklabels(tick_labels)
def _sci_format(n):
label = ''
if np.abs(n) < 0.01 or np.abs(n) >= 1000:
a = '{:.1E}'.format(n)
label = '$' + a.split('E')[0].rstrip('0').rstrip('.') + '\\cdot 10^{'
exponent = a.split('E')[1].lstrip('0')
if exponent[0] == '-':
exponent = exponent[0] + exponent[1:].lstrip('0')
elif exponent[0] == '+':
exponent = exponent[1:].lstrip('0')
label += exponent + '}$'
elif np.abs(n) >= 100:
label = '{:.0f}'.format(n)
elif np.abs(n) >= 10:
label = '{:.1f}'.format(n)
else:
label = '{:.2f}'.format(n)
return label
def _make_patch_spines_invisible(ax):
    '''
    Enable the frame of `ax` but hide its background patch and every spine.

    Used when stacking extra twin y axes, whose right spine is then made
    visible again by the caller.
    '''
    ax.set_frame_on(True)

    background = ax.patch
    background.set_visible(False)

    for spine_name in ax.spines:
        ax.spines[spine_name].set_visible(False)