# Source code for nngt.simulation.nest_plot

# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
# SPDX-License-Identifier: GPL-3.0-or-later
# nngt/simulation/nest_plot.py

""" Utility functions to plot NEST simulated activity """

import itertools
import logging

from matplotlib.colors import ColorConverter
import numpy as np
import nest

import nngt

from nngt.analysis import total_firing_rate
from nngt.lib import InvalidArgument, nonstring_container, is_integer
from nngt.lib.sorting import _sort_groups, _sort_neurons
from nngt.lib.logger import _log_message
from nngt.plot import palette_discrete, markers
from nngt.plot.plt_properties import _set_new_plot, _set_ax_lims

from .nest_utils import nest_version, spike_rec, _get_nest_gids

logger = logging.getLogger(__name__)


# --------------------- #
# Plotting the activity #
# --------------------- #

def plot_activity(gid_recorder=None, record=None, network=None, gids=None,
                  axis=None, show=False, limits=None, histogram=False,
                  title=None, fignum=None, label=None, sort=None,
                  average=False, normalize=1., decimate=None,
                  transparent=True, kernel_center=0., kernel_std=None,
                  resolution=None, cut_gaussian=5., **kwargs):
    '''
    Plot the monitored activity.

    .. versionchanged:: 1.2
        Switched `hist` to `histogram` and default value to False.

    .. versionchanged:: 1.0.1
        Added `axis` parameter, restored missing `fignum` parameter.

    Parameters
    ----------
    gid_recorder : tuple or list of tuples, optional (default: None)
        The gids of the recording devices. If None, then all existing
        spike recorders are used.
    record : tuple or list, optional (default: None)
        List of the monitored variables for each device. If `gid_recorder`
        is None, record can also be None and only spikes are considered;
        otherwise it is required and must contain one entry per recorder.
    network : :class:`~nngt.Network` or subclass, optional (default: None)
        Network which activity will be monitored.
    gids : tuple, optional (default: None)
        NEST gids of the neurons which should be monitored. If None, the
        gids of `network` are used or, without network, all the senders
        found in the recorders.
    axis : matplotlib axis object, optional (default: new one)
        Axis that should be used to plot the activity. This takes
        precedence over `fignum`.
    show : bool, optional (default: False)
        Whether to show the plot right away or to wait for the next
        plt.show().
    limits : tuple, optional (default: None)
        Time limits of the plot (if not specified, times of first and last
        spike for raster plots).
    histogram : bool, optional (default: False)
        Whether to display the histogram when plotting spikes rasters.
    title : str, optional (default: None)
        Title of the plot.
    fignum : int, or dict, optional (default: None)
        Plot the activity on an existing figure (from ``figure.number``).
        A dict maps recorder models to figure numbers; it is copied, never
        modified. This parameter is ignored if `axis` is provided.
    label : str or list, optional (default: None)
        Add labels to the plot (one per recorder).
    sort : str or list, optional (default: None)
        Sort neurons using a topological property ("in-degree",
        "out-degree", "total-degree" or "betweenness"), an activity-related
        property ("firing_rate" or neuronal property) or a user-defined
        list of sorted neuron ids. Sorting is performed by increasing value
        of the `sort` property from bottom to top inside each group.
        Requires `network`. If None and `network` is spatial, neurons are
        sorted by position.
    average : bool, optional (default: False)
        Currently unused.
    normalize : float, list or :class:`~nngt.NeuralPop`, optional
        (default: 1.)
        Normalize the recorded results of voltmeters/multimeters by a given
        float. If a list is provided, there should be one entry per
        voltmeter or multimeter in the recorders. If the recording was done
        through `monitor_groups`, the population can be passed to normalize
        the data by the number of nodes in each group. None disables the
        normalization.
    decimate : int or list of ints, optional (default: None)
        Represent only a fraction of the spiking neurons; only one neuron in
        `decimate` will be represented (e.g. setting `decimate` to 5 will
        lead to only 20% of the neurons being represented). If a list is
        provided, it must have one entry per NeuralGroup in the population.
    transparent : bool, optional (default: True)
        Whether the background of the raster figures should be hidden.
    kernel_center : float, optional (default: 0.)
        Temporal shift of the Gaussian kernel, in ms (for the histogram).
    kernel_std : float, optional (default: 0.5% of the latest event time)
        Characteristic width of the Gaussian kernel (standard deviation) in
        ms (for the histogram).
    resolution : float or array, optional (default: `0.5*kernel_std`)
        The resolution at which the firing rate values will be computed.
        Choosing a value smaller than `kernel_std` is strongly advised.
        If resolution is an array, it will be considered as the times where
        the firing rate should be computed (for the histogram).
    cut_gaussian : float, optional (default: 5.)
        Range over which the Gaussian will be computed (for the histogram).
        By default, we consider the 5-sigma range. Decreasing this value
        will increase speed at the cost of lower fidelity; increasing it
        will increase the fidelity at the cost of speed.
    **kwargs : dict
        "color" and "alpha" values for voltmeter/multimeter curves can be
        overridden here.

    Warning
    -------
    Sorting with "firing_rate" only works if NEST gids form a continuous
    integer range.

    Returns
    -------
    lines : dict
        Lines containing the data that was plotted, grouped by recorder
        model (each value is a list of :class:`matplotlib.lines.Line2D`).
    '''
    import matplotlib.pyplot as plt

    recorders = _get_nest_gids([])

    # `lines` is returned; `handles` keeps exactly one artist per entry of
    # `labels` so that legends stay aligned (histogram lines are returned in
    # `lines` but must not consume a legend entry).
    lines, handles, labels = {}, {}, {}

    # axes shared between successive spike recorders
    axes = {}

    # normalize recorders and recordables
    if gid_recorder is not None:
        assert record is not None, "`record` must also be provided."

        if len(record) != len(gid_recorder):
            raise InvalidArgument('`record` must either be the same for all '
                                  'recorders, or contain one entry per '
                                  'recorder in `gid_recorder`')

        if nest_version == 3:
            recorders = _get_nest_gids(gid_recorder)
        else:
            for rec in gid_recorder:
                if isinstance(gid_recorder[0], tuple):
                    recorders.append(rec)
                else:
                    recorders.append((rec,))
    else:
        prop = {'model': spike_rec}

        if nest_version == 3:
            recorders = nest.GetNodes(properties=prop)
        else:
            recorders = [
                (gid,) for gid in nest.GetNodes((0,), properties=prop)[0]
            ]

        record = tuple("spikes" for _ in range(len(recorders)))

    # get gids and groups
    if gids is None and network is not None:
        gids = network.nest_gids

    if gids is None:
        gids = []

        for rec in recorders:
            gids.extend(nest.GetStatus(rec)[0]["events"]["senders"])

        gids = np.unique(gids)

    num_group = 1 if network is None else len(network.population)
    num_lines = max(num_group, len(recorders))

    # default ordering: gid `g` is displayed at position `g - min(gids) + 1`
    sorted_neurons = np.array([])

    if len(gids):
        sorted_neurons = np.arange(
            np.max(gids) + 1).astype(int) - np.min(gids) + 1

    attr = None

    if sort is not None:
        assert network is not None, "`network` is required for sorting."

        if nonstring_container(sort):
            attr = sort
            sorted_neurons = _sort_neurons(attr, gids, network)
            sort = "user defined sort"
        else:
            data = None

            if sort.lower() in ("firing_rate", "b2"):
                # activity-based sorting needs the (senders, times) pairs
                data = [[], []]

                for rec in recorders:
                    info = nest.GetStatus(rec)[0]

                    if str(info["model"]) == spike_rec:
                        data[0].extend(info["events"]["senders"])
                        data[1].extend(info["events"]["times"])

                data = np.array(data).T

            sorted_neurons, attr = _sort_neurons(
                sort, gids, network, data=data, return_attr=True)
    elif network is not None and network.is_spatial():
        sorted_neurons, attr = _sort_neurons(
            "space", gids, network, data=None, return_attr=True)

    # spikes plotting
    colors = palette_discrete(np.linspace(0, 1, num_lines))
    num_raster, num_detec, num_meter = 0, 0, 0

    # copy so that the caller's dict is never modified
    fignums = dict(fignum) if isinstance(fignum, dict) else {}

    # a dict is not a valid figure number: fall back to a new figure
    default_fnum = None if isinstance(fignum, dict) else fignum

    if decimate is None:
        decim = [None for _ in range(num_lines)]
    elif is_integer(decimate):
        decim = [decimate for _ in range(num_lines)]
    elif nonstring_container(decimate):
        assert len(decimate) == num_lines, "`decimate` should have one " +\
                                           "entry per plot."
        decim = decimate
    else:
        raise AttributeError(
            "`decimate` must be either an int or a list of `int`.")

    # set labels
    if label is None:
        lst_labels = [None for _ in range(len(recorders))]
    else:
        lst_labels = [label] if isinstance(label, str) else label

        # compare the normalized list: the length of a string is its number
        # of characters, not a number of labels
        if len(lst_labels) != len(recorders):
            _log_message(logger, "WARNING",
                         'Incorrect length for `label`: expecting {} but got '
                         '{}.\nIgnoring.'.format(len(recorders),
                                                 len(lst_labels)))
            lst_labels = [None for _ in range(len(recorders))]

    datasets = []
    max_time = 0.

    for rec in recorders:
        info = nest.GetStatus(rec)[0]

        if len(info["events"]["times"]):
            max_time = max(max_time, np.max(info["events"]["times"]))

        datasets.append(info)

    if kernel_std is None:
        kernel_std = max_time*0.005

    if resolution is None:
        resolution = 0.5*kernel_std

    # plot
    for info, var, lbl in zip(datasets, record, lst_labels):
        model = info["model"]
        fnum = fignums.get(model, default_fnum)

        if model not in labels:
            labels[model] = []
            lines[model] = []
            handles[model] = []

        if str(model) == spike_rec:
            # all spike recorders share the same raster axis; a local name
            # avoids redirecting later (non-spike) recorders onto it
            raster_axis = axes.get(spike_rec, axis)

            c = colors[num_raster]
            times, senders = info["events"]["times"], info["events"]["senders"]
            sorted_ids = sorted_neurons[senders]

            l = raster_plot(times, sorted_ids, color=c, show=False,
                            limits=limits, sort=sort, fignum=fnum,
                            axis=raster_axis, decimate=decim[num_raster],
                            sort_attribute=attr, network=network,
                            histogram=histogram, transparent=transparent,
                            hist_ax=axes.get('histogram', None),
                            kernel_center=kernel_center,
                            kernel_std=kernel_std, resolution=resolution,
                            cut_gaussian=cut_gaussian)

            num_raster += 1

            if l:
                fignums[spike_rec] = l[0].figure.number
                axes[spike_rec] = l[0].axes
                labels[spike_rec].append(lbl)
                lines[spike_rec].extend(l)
                handles[spike_rec].append(l[0])

                if histogram and len(l) > 1:
                    axes['histogram'] = l[1].axes
        elif "detector" in str(model):
            c = colors[num_detec]
            times, senders = info["events"]["times"], info["events"]["senders"]
            sorted_ids = sorted_neurons[senders]

            l = raster_plot(times, sorted_ids, fignum=fnum, color=c,
                            axis=axis, show=False, histogram=histogram,
                            limits=limits, kernel_center=kernel_center,
                            kernel_std=kernel_std, resolution=resolution,
                            cut_gaussian=cut_gaussian)

            if l:
                num_detec += 1
                fignums[model] = l[0].figure.number
                labels[model].append(lbl)
                lines[model].extend(l)
                handles[model].append(l[0])

                if histogram and len(l) > 1:
                    axes['histogram'] = l[1].axes
        else:
            da_time = info["events"]["times"]

            # prepare axis setup
            if axis is None:
                fig = plt.figure(fnum)
                fignums[model] = fig.number
            else:
                fig = axis.get_figure()

            lines_tmp, labels_tmp = [], []

            if nonstring_container(var):
                m_colors = palette_discrete(np.linspace(0, 1, len(var)))

                # distinct name: `axes` is the dict of shared raster axes
                meter_axes = fig.axes

                if axis is not None:
                    # multiple y axes on a single subplot, adapted from
                    # https://matplotlib.org/examples/pylab_examples/
                    # multiple_yaxis_with_spines.html
                    meter_axes = [axis]
                    axis.name = var[0]

                    if len(var) > 1:
                        meter_axes.append(axis.twinx())
                        meter_axes[-1].name = var[1]

                    if len(var) > 2:
                        fig.subplots_adjust(right=0.75)

                        for i, name in enumerate(var[2:]):
                            new_ax = axis.twinx()
                            new_ax.spines["right"].set_position(
                                ("axes", 1.2*(i+1)))
                            meter_axes.append(new_ax)
                            _make_patch_spines_invisible(new_ax)
                            new_ax.spines["right"].set_visible(True)
                            meter_axes[-1].name = name

                if not meter_axes:
                    meter_axes = _set_new_plot(fig.number, names=var)[1]

                labels_tmp = [lbl for _ in range(len(var))]
                alpha = kwargs.get('alpha', 1)

                for subvar, c in zip(var, m_colors):
                    c = kwargs.get('color', c)

                    for ax in meter_axes:
                        # pre-existing axes (e.g. a raster) have no `name`
                        if getattr(ax, "name", None) == subvar:
                            da_subvar = _normalize_recording(
                                info["events"][subvar], normalize, num_meter)

                            lines_tmp.extend(
                                ax.plot(da_time, da_subvar, color=c,
                                        alpha=alpha))

                            ax.set_ylabel(subvar)
                            ax.set_xlabel("time")

                            if limits is not None:
                                ax.set_xlim(limits[0], limits[1])
            else:
                num_axes, ax = len(fig.axes), axis

                if axis is None:
                    ax = fig.add_subplot(num_axes + 1, 1, num_axes + 1)

                da_var = _normalize_recording(
                    info["events"][var], normalize, num_meter)

                c = kwargs.get('color', None)
                alpha = kwargs.get('alpha', 1)

                lines_tmp.extend(ax.plot(da_time, da_var, color=c,
                                         alpha=alpha))
                labels_tmp.append(lbl)

                ax.set_ylabel(var)
                ax.set_xlabel("time")

            labels[model].extend(labels_tmp)
            lines[model].extend(lines_tmp)
            handles[model].extend(lines_tmp)

            num_meter += 1

    if spike_rec in axes:
        ax = axes[spike_rec]

        if limits is not None:
            ax.set_xlim(limits[0], limits[1])
        else:
            # fit the view to all the rasters, with a 2% margin
            t_min, t_max, idx_min, idx_max = np.inf, -np.inf, np.inf, -np.inf

            for line in ax.lines:
                xdata, ydata = line.get_xdata(), line.get_ydata()

                if len(xdata):
                    t_min = min(np.min(xdata), t_min)
                    t_max = max(np.max(xdata), t_max)

                if len(ydata):
                    idx_min = min(np.min(ydata), idx_min)
                    idx_max = max(np.max(ydata), idx_max)

            dt = t_max - t_min
            didx = idx_max - idx_min
            pc = 0.02

            if not np.any(np.isinf((t_max, t_min))):
                ax.set_xlim([t_min - pc*dt, t_max + pc*dt])

            if not np.any(np.isinf((idx_min, idx_max))):
                ax.set_ylim([idx_min - pc*didx, idx_max + pc*didx])

    for recorder, num in fignums.items():
        fig = plt.figure(num)

        if title is not None:
            fig.suptitle(title)

        rec_labels = labels.get(recorder, [])

        # skip the legend if the labels were discarded (length mismatch)
        if label is not None and any(lb is not None for lb in rec_labels):
            fig.legend(handles.get(recorder, []), rec_labels)

    if show:
        plt.show()

    return lines


def _normalize_recording(values, normalize, index):
    '''
    Return `values` divided according to `normalize` for the `index`-th
    voltmeter/multimeter (see the `normalize` parameter of `plot_activity`).

    A population divides by the size of its `index`-th group, a container by
    its `index`-th entry, a scalar by itself; None leaves `values` unchanged.
    A new array is returned (the recorded data is not modified in place).
    '''
    if isinstance(normalize, nngt.NeuralPop):
        return values / normalize[index].size

    if nonstring_container(normalize):
        return values / normalize[index]

    if normalize is not None:
        return values / normalize

    return values
def raster_plot(times, senders, limits=None, title="Spike raster",
                histogram=False, num_bins=1000, color="b", decimate=None,
                axis=None, fignum=None, label=None, show=True, sort=None,
                sort_attribute=None, network=None, transparent=True,
                kernel_center=0., kernel_std=30., resolution=None,
                cut_gaussian=5., **kwargs):
    """
    Plotting routine that constructs a raster plot along with
    an optional histogram.

    .. versionchanged:: 1.2
        Switched `hist` to `histogram`.

    .. versionchanged:: 1.0.1
        Added `axis` parameter.

    Parameters
    ----------
    times : list or :class:`numpy.ndarray`
        Spike times.
    senders : list or :class:`numpy.ndarray`
        Index for the spiking neuron for each time in `times`.
    limits : tuple, optional (default: None)
        Time limits of the plot (if not specified, times of first and last
        spike).
    title : string, optional (default: 'Spike raster')
        Title of the raster plot.
    histogram : bool, optional (default: False)
        Whether to plot the raster's histogram (Gaussian-smoothed firing
        rate). Successive histograms on the same axis are stacked.
    num_bins : int, optional (default: 1000)
        Currently unused (kept for backward compatibility).
    color : string or float, optional (default: 'b')
        Color of the plot lines and markers.
    decimate : int, optional (default: None)
        Represent only a fraction of the spiking neurons; only senders whose
        index is a multiple of `decimate` are represented (e.g. setting
        `decimate` to 10 will lead to only 10% of the neurons being
        represented).
    axis : matplotlib axis object, optional (default: new one)
        Axis that should be used to plot the activity.
    fignum : int, optional (default: None)
        Id of another raster plot to which the new data should be added.
    label : str, optional (default: None)
        Label the current data.
    show : bool, optional (default: True)
        Whether to show the plot right away or to wait for the next
        plt.show().
    sort : str, optional (default: None)
        Name of the sorting property; if set, a twin axis displaying the
        values of `sort_attribute` is added.
    sort_attribute : array-like, optional (default: None)
        Values of the sorting property.
    network : :class:`~nngt.Network`, optional (default: None)
        If provided (and `histogram` is False), each group of the
        population is plotted with its own color and marker.
    transparent : bool, optional (default: True)
        Whether the figure background should be hidden.
    kernel_center : float, optional (default: 0.)
        Temporal shift of the Gaussian kernel, in ms.
    kernel_std : float, optional (default: 30.)
        Characteristic width of the Gaussian kernel (standard deviation) in
        ms.
    resolution : float or array, optional (default: None)
        The resolution at which the firing rate values will be computed
        (None lets :func:`~nngt.analysis.total_firing_rate` choose).
        Choosing a value smaller than `kernel_std` is strongly advised.
        If resolution is an array, it will be considered as the times where
        the firing rate should be computed.
    cut_gaussian : float, optional (default: 5.)
        Range over which the Gaussian will be computed (for the histogram).
        By default, we consider the 5-sigma range. Decreasing this value
        will increase speed at the cost of lower fidelity; increasing it
        will increase the fidelity at the cost of speed.
    **kwargs : dict
        Forwarded to matplotlib's ``plot`` for the raster markers, except
        the internal ``hist_ax`` entry, used as the histogram axis.

    Returns
    -------
    lines : list of :class:`matplotlib.lines.Line2D`
        Lines containing the data that was plotted.
    """
    import matplotlib.pyplot as plt

    lines = []

    # `hist_ax` is an internal option that must not reach matplotlib
    hist_ax = kwargs.get("hist_ax", None)
    mpl_kwargs = {k: v for k, v in kwargs.items() if k != 'hist_ax'}

    if label is not None:
        mpl_kwargs['label'] = label

    # masks and comparisons below require arrays (lists are documented)
    times = np.asarray(times)
    senders = np.asarray(senders)

    # decimate if necessary
    if decimate is not None:
        keep = np.mod(senders, decimate) == 0
        times, senders = times[keep], senders[keep]

    if not len(times):
        _log_message(logger, "WARNING",
                     "No activity was detected during the simulation.")
        return lines

    fig = axis.get_figure() if axis is not None else plt.figure(fignum)

    if transparent:
        fig.patch.set_visible(False)

    ylabel = "Neuron ID"
    xlabel = "Time (ms)"

    # margin computed over all (decimated) spikes, before time filtering
    delta_t = 0.01*(np.max(times) - np.min(times))

    if limits is not None:
        start, stop = limits
        keep = (times >= start) & (times <= stop)
        times, senders = times[keep], senders[keep]

    if histogram:
        if hist_ax is None:
            num_axes = len(fig.axes)
            _stack_axes(fig, num_axes + 2)
            ax1 = fig.add_subplot(num_axes + 2, 1, num_axes + 1)
            ax2 = fig.add_subplot(num_axes + 2, 1, num_axes + 2, sharex=ax1)
        else:
            ax1, ax2 = axis, hist_ax

        lines.extend(ax1.plot(
            times, senders, c=color, marker="o", linestyle='None',
            mec="k", mew=0.5, ms=4, **mpl_kwargs))

        ax1.set_ylabel(ylabel)

        if limits is not None:
            ax1.set_xlim(*limits)
        elif len(ax1.lines) > 1:
            # several rasters share this axis: make room for all of them
            t_max = max(np.max(line.get_xdata()) for line in ax1.lines
                        if len(line.get_xdata()))
            ax1.set_xlim([-delta_t, t_max + delta_t])

        if len(times):
            fr, fr_times = total_firing_rate(
                data=np.array([senders, times]).T,
                kernel_center=kernel_center, kernel_std=kernel_std,
                resolution=resolution, cut_gaussian=cut_gaussian)

            # stack on top of the previous histogram: its cumulative rate is
            # stored in the last (invisible) line of `ax2`; resample it on
            # the new times, with 0 outside of its range
            bottom = 0.
            hist_lines = ax2.get_lines()

            if hist_lines:
                old_times, old_rates = hist_lines[-1].get_data()
                bottom = np.interp(fr_times, old_times, old_rates,
                                   left=0., right=0.)

            ax2.fill_between(fr_times, fr + bottom, bottom, color=color)
            lines.extend(ax2.plot(fr_times, fr + bottom, ls="", marker=""))

        ax2.set_ylabel("Rate (Hz)")
        ax2.set_xlabel(xlabel)
        ax2.set_xlim(ax1.get_xlim())

        _second_axis(sort, sort_attribute, ax1)
    else:
        if axis is not None:
            ax = axis
        else:
            num_axes = len(fig.axes)
            _stack_axes(fig, num_axes + 1)
            ax = fig.add_subplot(num_axes + 1, 1, num_axes + 1)

        if network is not None:
            pop = network.population
            colors = palette_discrete(np.linspace(0, 1, len(pop)))

            for m, (k, v), c in zip(itertools.cycle(markers), pop.items(),
                                    colors):
                keep = np.isin(senders, network.nest_gids[v.ids])

                if np.any(keep):
                    group_kwargs = dict(mpl_kwargs)

                    if label is None:
                        group_kwargs['label'] = k

                    lines.extend(ax.plot(
                        times[keep], senders[keep], c=c, marker=m,
                        ls='None', mec='k', mew=0.5, ms=4, **group_kwargs))
        else:
            lines.extend(ax.plot(
                times, senders, c=color, marker="o", linestyle='None',
                mec="k", mew=0.5, ms=4, **mpl_kwargs))

        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

        if limits is not None:
            ax.set_xlim(limits)
        else:
            _set_ax_lims(ax, np.max(times), np.min(times), np.max(senders),
                         np.min(senders))

        if label is not None:
            ax.legend(bbox_to_anchor=(1.1, 1.2))

        _second_axis(sort, sort_attribute, ax)

    fig.suptitle(title)

    if show:
        plt.show()

    return lines


def _stack_axes(fig, num_rows):
    '''
    Move the existing axes of `fig` so that axis ``i`` occupies row
    ``i + 1`` of a ``num_rows`` x 1 grid, leaving the last rows free.

    Uses ``Axes.change_geometry`` when available (matplotlib < 3.6) and
    ``Axes.set_subplotspec`` otherwise (``change_geometry`` was removed).
    '''
    grid = None

    for i, old_ax in enumerate(fig.axes):
        if hasattr(old_ax, "change_geometry"):
            old_ax.change_geometry(num_rows, 1, i + 1)
        else:
            if grid is None:
                from matplotlib.gridspec import GridSpec
                grid = GridSpec(num_rows, 1, figure=fig)

            old_ax.set_subplotspec(grid[i])
#-----------------------------------------------------------------------------
# Tools
#------------------------
#

def _fill_between_steps(x, y1, y2=0, h_align='mid'):
    '''
    Fills a hole in matplotlib: fill_between for step plots.

    Returns the coordinates to pass to ``fill_between`` so that the filled
    area follows steps of constant value instead of straight segments.

    Parameters
    ----------
    x : array-like
        Array/vector of index values. These are assumed to be equally-spaced
        and increasing; otherwise the result will probably look weird.
    y1 : array-like
        Array/vector of values to be filled under.
    y2 : scalar or array-like, optional (default: 0)
        Bottom values for the filled area (arrays/lists are duplicated like
        `y1`, scalars are returned unchanged).
    h_align : str, optional (default: 'mid')
        Step alignment: 'mid' centers each step on its x value, 'right' ends
        each step on it, any other value starts each step on it.

    Returns
    -------
    xx, y1, y2 : step coordinates (``xx`` and ``y1`` have ``2*len(x)``
        entries).
    '''
    x = np.asarray(x)

    # each x value starts a step and ends the previous one
    xx = np.repeat(x, 2)[1:]

    # per-point step widths (first/last widths duplicated at the borders)
    xstep = np.repeat(x[1:] - x[:-1], 2)
    xstep = np.concatenate(([xstep[0]], xstep, [xstep[-1]]))

    # close the last step one width after the last value
    xx = np.append(xx, xx[-1] + xstep[-1])

    # not in-place: integer `x` must be allowed to become float
    if h_align == 'mid':
        xx = xx - xstep / 2.
    elif h_align == 'right':
        xx = xx - xstep

    # duplicate each y coordinate to get flat steps
    y1 = np.repeat(y1, 2)

    if np.ndim(y2) > 0:
        y2 = np.repeat(y2, 2)

    return xx, y1, y2


def _moving_average(values, window):
    '''
    Simple moving average of `values` over `window` points (np.convolve in
    "same" mode, so the output is centered on the input).
    '''
    weights = np.repeat(1.0, window) / window
    return np.convolve(values, weights, 'same')


def _second_axis(sort, sort_attribute, ax):
    '''
    Add a twin y-axis to `ax` whose tick labels show the sorting attribute.

    Nothing is done if `sort` or `sort_attribute` is None, or if the figure
    already contains an axis whose y-label is `sort`.

    Parameters
    ----------
    sort : str or None
        Name of the sorting property, used as the twin axis' y-label.
    sort_attribute : array-like or None
        Values of the sorting property; the label displayed at tick ``t``
        is the ``int(t)``-th smallest value (clipped to the largest one).
    ax : matplotlib axis
        Axis containing the raster whose ticks are mirrored.
    '''
    import matplotlib.pyplot as plt

    if sort is None or sort_attribute is None:
        return

    fig = ax.get_figure()

    # a twin axis for this property already exists
    if any(other.get_ylabel() == sort for other in fig.axes):
        return

    asort = np.argsort(sort_attribute)

    twin = ax.twinx()
    twin.grid(False)
    twin.set_ylabel(sort)

    # refresh the figure before reading the tick positions of `ax`
    plt.draw()

    old_ticks = ax.get_yticks()

    twin.set_yticks(old_ticks)
    twin.set_ylim(ax.get_ylim())

    idx_max = len(sort_attribute) - 1

    labels = [
        _sci_format(sort_attribute[asort[min(int(t), idx_max)]])
        if t >= 0 else ''
        for t in old_ticks
    ]

    twin.set_yticklabels(labels)


def _sci_format(n):
    '''
    Format `n` as a short tick label.

    Non-zero values with ``|n| < 0.01`` or ``|n| >= 1000`` use a LaTeX
    scientific notation (e.g. ``$1.2\\cdot 10^{3}$``); otherwise 0, 1 or 2
    decimals are used for ``|n| >= 100``, ``>= 10`` and below. Non-finite
    values give 'inf', '-inf' or 'nan'.
    '''
    if not np.isfinite(n):
        return '{:.2f}'.format(n)

    magnitude = np.abs(n)

    # 0 must not use the scientific notation (it would give "10^{}")
    if n != 0 and (magnitude < 0.01 or magnitude >= 1000):
        mantissa, exponent = '{:.1E}'.format(n).split('E')

        mantissa = mantissa.rstrip('0').rstrip('.')

        # exponent is "+XX" or "-XX": drop "+" and leading zeros
        sign = '-' if exponent[0] == '-' else ''
        exponent = sign + (exponent[1:].lstrip('0') or '0')

        return '$' + mantissa + '\\cdot 10^{' + exponent + '}$'

    if magnitude >= 100:
        return '{:.0f}'.format(n)

    if magnitude >= 10:
        return '{:.1f}'.format(n)

    return '{:.2f}'.format(n)


def _make_patch_spines_invisible(ax):
    ''' Hide the background patch and every spine of `ax` (frame kept on). '''
    ax.set_frame_on(True)
    ax.patch.set_visible(False)

    for sp in ax.spines.values():
        sp.set_visible(False)