#!/usr/bin/env python
#-*- coding:utf-8 -*-
#
# This file is part of the NNGT project to generate and analyze
# neuronal networks and their activity.
# Copyright (C) 2015-2019 Tanguy Fardet
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# nest_plot.py
# This file is part of the NNGT module
# Distributed as a free software, in the hope that it will be useful, under the
# terms of the GNU General Public License.
""" Utility functions to plot NEST simulated activity """
import itertools
import logging
from matplotlib.colors import ColorConverter
import numpy as np
import nest
import nngt
from nngt.analysis import total_firing_rate
from nngt.lib import InvalidArgument, nonstring_container, is_integer
from nngt.lib.sorting import _sort_groups, _sort_neurons
from nngt.lib.logger import _log_message
from nngt.plot import palette_discrete, markers
from nngt.plot.plt_properties import _set_new_plot, _set_ax_lims
# module-level logger, passed to `_log_message` for the warnings emitted below
logger = logging.getLogger(__name__)
# --------------------- #
# Plotting the activity #
# --------------------- #
def plot_activity(gid_recorder=None, record=None, network=None, gids=None,
                  axis=None, show=False, limits=None, histogram=False,
                  title=None, fignum=None, label=None, sort=None,
                  average=False, normalize=1., decimate=None, transparent=True,
                  kernel_center=0., kernel_std=None, resolution=None,
                  cut_gaussian=5.):
    '''
    Plot the monitored activity.

    .. versionchanged:: 1.2
        Switched `hist` to `histogram` and default value to False.

    .. versionchanged:: 1.0.1
        Added `axis` parameter, restored missing `fignum` parameter.

    Parameters
    ----------
    gid_recorder : tuple or list of tuples, optional (default: None)
        The gids of the recording devices. If None, then all existing
        "spike_detector"s are used. Tuple entries are reduced to their first
        element.
    record : tuple or list, optional (default: None)
        List of the monitored variables for each device. If `gid_recorder` is
        None, record can also be None and only spikes are considered.
        Otherwise it must contain one entry per recorder.
    network : :class:`~nngt.Network` or subclass, optional (default: None)
        Network which activity will be monitored.
    gids : tuple, optional (default: None)
        NEST gids of the neurons which should be monitored.
    axis : matplotlib axis object, optional (default: new one)
        Axis that should be used to plot the activity. This takes precedence
        over `fignum`.
    show : bool, optional (default: False)
        Whether to show the plot right away or to wait for the next plt.show().
    limits : tuple, optional (default: None)
        Time limits of the plot (if not specified, times of first and last
        spike for raster plots).
    histogram : bool, optional (default: False)
        Whether to display the histogram when plotting spikes rasters.
    title : str, optional (default: None)
        Title of the plot.
    fignum : int, or dict, optional (default: None)
        Plot the activity on an existing figure (from ``figure.number``). A
        dict mapping recorder model names to figure numbers can also be
        given (it is not modified). This parameter is ignored if `axis` is
        provided.
    label : str or list, optional (default: None)
        Add labels to the plot (one per recorder).
    sort : str or list, optional (default: None)
        Sort neurons using a topological property ("in-degree", "out-degree",
        "total-degree" or "betweenness"), an activity-related property
        ("firing_rate" or neuronal property) or a user-defined list of sorted
        neuron ids. Sorting is performed by increasing value of the `sort`
        property from bottom to top inside each group.
    average : bool, optional (default: False)
        Currently unused.
    normalize : float, list or :class:`~nngt.NeuralPop`, optional (default: 1.)
        Normalize the recorded results by a given float. If a list is provided,
        there should be one entry per voltmeter or multimeter in the recorders.
        If the recording was done through `monitor_groups`, the population can
        be passed to normalize the data by the number of nodes in each group.
        If None, the data is not normalized.
    decimate : int or list of ints, optional (default: None)
        Represent only a fraction of the spiking neurons; only one neuron in
        `decimate` will be represented (e.g. setting `decimate` to 5 will lead
        to only 20% of the neurons being represented). If a list is provided,
        it must have one entry per NeuralGroup in the population.
    transparent : bool, optional (default: True)
        Whether to hide the figure background of the spike raster plots.
    kernel_center : float, optional (default: 0.)
        Temporal shift of the Gaussian kernel, in ms (for the histogram).
    kernel_std : float, optional (default: 0.5% of the last event time)
        Characteristic width of the Gaussian kernel (standard deviation) in ms
        (for the histogram).
    resolution : float or array, optional (default: `0.5*kernel_std`)
        The resolution at which the firing rate values will be computed.
        Choosing a value smaller than `kernel_std` is strongly advised.
        If resolution is an array, it will be considered as the times where
        the firing rate should be computed (for the histogram).
    cut_gaussian : float, optional (default: 5.)
        Range over which the Gaussian will be computed (for the histogram).
        By default, we consider the 5-sigma range. Decreasing this value will
        increase speed at the cost of lower fidelity; increasing it will
        increase the fidelity at the cost of speed.

    Warning
    -------
    Sorting with "firing_rate" only works if NEST gids form a continuous
    integer range.

    Returns
    -------
    lines : dict
        Lines containing the data that was plotted, grouped by recorder
        model name (e.g. "spike_detector").
    '''
    import matplotlib.pyplot as plt

    # `axes` maps keys ("spike_detector", "histogram") to the matplotlib axes
    # they were drawn on; it is read again after the loop, so it must never
    # be rebound to anything else.
    lst_rec, lst_labels, lines, axes, labels = [], [], {}, {}, {}

    # normalize recorders and recordables
    if gid_recorder is not None:
        assert record is not None, "`record` must also be provided."
        if len(record) != len(gid_recorder):
            raise InvalidArgument('`record` must either be the same for all '
                                  'recorders, or contain one entry per '
                                  'recorder in `gid_recorder`')
        for rec in gid_recorder:
            lst_rec.append(rec[0] if isinstance(rec, tuple) else rec)
    else:
        lst_rec = nest.GetNodes(
            (0,), properties={'model': 'spike_detector'})[0]
        record = tuple("spikes" for _ in range(len(lst_rec)))

    # get gids and groups
    if gids is None and network is not None:
        gids = network.nest_gids

    if gids is None:
        gids = []
        for rec in lst_rec:
            gids.extend(nest.GetStatus([rec])[0]["events"]["senders"])
        gids = np.unique(gids)

    num_group = 1 if network is None else len(network.population)
    num_lines = max(num_group, len(lst_rec))

    # sorting: `sorted_neurons[gid]` is the y position used for neuron `gid`
    sorted_neurons = np.array([])

    if len(gids):
        sorted_neurons = np.arange(
            np.max(gids) + 1).astype(int) - np.min(gids) + 1

    attr = None

    if sort is not None:
        assert network is not None, "`network` is required for sorting."
        if nonstring_container(sort):
            attr = sort
            sorted_neurons = _sort_neurons(attr, gids, network)
            sort = "user defined sort"
        else:
            data = None
            if sort.lower() in ("firing_rate", "b2"):  # get senders
                data = [[], []]
                for rec in lst_rec:
                    info = nest.GetStatus([rec])[0]
                    if str(info["model"]) == "spike_detector":
                        data[0].extend(info["events"]["senders"])
                        data[1].extend(info["events"]["times"])
                data = np.array(data).T
            sorted_neurons, attr = _sort_neurons(
                sort, gids, network, data=data, return_attr=True)
    elif network is not None and network.is_spatial():
        sorted_neurons, attr = _sort_neurons(
            "space", gids, network, data=None, return_attr=True)

    # spikes plotting
    colors = palette_discrete(np.linspace(0, 1, num_lines))
    num_raster, num_detec, num_meter = 0, 0, 0

    # copy so that the caller's dict is not modified
    fignums = dict(fignum) if isinstance(fignum, dict) else {}
    # figure number used for models that have no entry in `fignums`
    default_fnum = None if isinstance(fignum, dict) else fignum

    if decimate is None:
        decim = [None for _ in range(num_lines)]
    elif is_integer(decimate):
        decim = [decimate for _ in range(num_lines)]
    elif nonstring_container(decimate):
        assert len(decimate) == num_lines, "`decimate` should have one " \
                                           "entry per plot."
        decim = decimate
    else:
        raise AttributeError(
            "`decimate` must be either an int or a list of `int`.")

    # set labels
    if label is None:
        lst_labels = [None for _ in range(len(lst_rec))]
    else:
        lst_labels = [label] if isinstance(label, str) else label
        # check the normalized list: len(label) would count characters
        # when a single string is given
        if len(lst_labels) != len(lst_rec):
            _log_message(logger, "WARNING",
                         'Incorrect length for `label`: expecting {} but got '
                         '{}.\nIgnoring.'.format(len(lst_rec),
                                                 len(lst_labels)))
            lst_labels = [None for _ in range(len(lst_rec))]

    datasets = []
    max_time = 0.

    for rec in lst_rec:
        info = nest.GetStatus([rec])[0]
        if len(info["events"]["times"]):
            max_time = max(max_time, np.max(info["events"]["times"]))
        datasets.append(info)

    if kernel_std is None:
        kernel_std = max_time*0.005

    if resolution is None:
        resolution = 0.5*kernel_std

    # plot
    for info, var, lbl in zip(datasets, record, lst_labels):
        model = info["model"]
        fnum = fignums.get(model, default_fnum)

        if model not in labels:
            labels[model] = []
            lines[model] = []

        if str(model) == "spike_detector":
            if "spike_detector" in axes:
                axis = axes["spike_detector"]
            c = colors[num_raster]
            times, senders = info["events"]["times"], info["events"]["senders"]
            sorted_ids = sorted_neurons[senders]
            rlines = raster_plot(
                times, sorted_ids, color=c, show=False, limits=limits,
                sort=sort, fignum=fnum, axis=axis,
                decimate=decim[num_raster], sort_attribute=attr,
                network=network, histogram=histogram,
                transparent=transparent, hist_ax=axes.get('histogram', None),
                kernel_center=kernel_center, kernel_std=kernel_std,
                resolution=resolution, cut_gaussian=cut_gaussian)
            num_raster += 1
            if rlines:
                fignums['spike_detector'] = rlines[0].figure.number
                axes['spike_detector'] = rlines[0].axes
                labels["spike_detector"].append(lbl)
                lines["spike_detector"].extend(rlines)
                if histogram:
                    axes['histogram'] = rlines[1].axes
        elif "detector" in str(model):
            c = colors[num_detec]
            times, senders = info["events"]["times"], info["events"]["senders"]
            sorted_ids = sorted_neurons[senders]
            rlines = raster_plot(
                times, sorted_ids, fignum=fnum, color=c, axis=axis,
                show=False, histogram=histogram, limits=limits,
                kernel_center=kernel_center, kernel_std=kernel_std,
                resolution=resolution, cut_gaussian=cut_gaussian)
            if rlines:
                num_detec += 1
                fignums[model] = rlines[0].figure.number
                labels[model].append(lbl)
                lines[model].extend(rlines)
                if histogram:
                    axes['histogram'] = rlines[1].axes
        else:
            da_time = info["events"]["times"]

            # prepare axis setup
            if axis is None:
                fig = plt.figure(fnum)
                fignums[model] = fig.number
            else:
                fig = axis.get_figure()

            lines_tmp, labels_tmp = [], []

            if nonstring_container(var):
                m_colors = palette_discrete(np.linspace(0, 1, len(var)))
                # local list of meter axes, distinct from the `axes` dict
                m_axes = fig.axes
                if axis is not None:
                    # multiple y axes on a single subplot, adapted from
                    # https://matplotlib.org/examples/pylab_examples/
                    # multiple_yaxis_with_spines.html
                    m_axes = [axis]
                    axis.name = var[0]
                    if len(var) > 1:
                        m_axes.append(axis.twinx())
                        m_axes[-1].name = var[1]
                    if len(var) > 2:
                        fig.subplots_adjust(right=0.75)
                        for i, name in enumerate(var[2:]):
                            new_ax = axis.twinx()
                            new_ax.spines["right"].set_position(
                                ("axes", 1.2*(i+1)))
                            m_axes.append(new_ax)
                            _make_patch_spines_invisible(new_ax)
                            new_ax.spines["right"].set_visible(True)
                            new_ax.name = name
                if not m_axes:
                    m_axes = _set_new_plot(fig.number, names=var)[1]
                labels_tmp = [lbl for _ in range(len(var))]
                for subvar, c in zip(var, m_colors):
                    for ax in m_axes:
                        if ax.name == subvar:
                            da_subvar = _normalize_data(
                                info["events"][subvar], normalize, num_meter)
                            lines_tmp.extend(
                                ax.plot(da_time, da_subvar, color=c))
                            ax.set_ylabel(subvar)
                            ax.set_xlabel("time")
                            if limits is not None:
                                ax.set_xlim(limits[0], limits[1])
            else:
                num_axes, ax = len(fig.axes), axis
                if axis is None:
                    ax = fig.add_subplot(num_axes + 1, 1, num_axes + 1)
                da_var = _normalize_data(
                    info["events"][var], normalize, num_meter)
                lines_tmp.extend(ax.plot(da_time, da_var))
                labels_tmp.append(lbl)
                ax.set_ylabel(var)
                ax.set_xlabel("time")

            labels[model].extend(labels_tmp)
            lines[model].extend(lines_tmp)
            num_meter += 1

    if "spike_detector" in axes:
        ax = axes['spike_detector']
        if limits is not None:
            ax.set_xlim(limits[0], limits[1])
        else:
            t_min, t_max, idx_min, idx_max = np.inf, -np.inf, np.inf, -np.inf
            for line in ax.lines:
                xdata, ydata = line.get_xdata(), line.get_ydata()
                if len(xdata):
                    t_min = min(np.min(xdata), t_min)
                    t_max = max(np.max(xdata), t_max)
                if len(ydata):
                    idx_min = min(np.min(ydata), idx_min)
                    idx_max = max(np.max(ydata), idx_max)
            dt = t_max - t_min
            didx = idx_max - idx_min
            pc = 0.02  # 2% margin around the plotted data
            if not np.any(np.isinf((t_max, t_min))):
                ax.set_xlim([t_min - pc*dt, t_max + pc*dt])
            if not np.any(np.isinf((idx_min, idx_max))):
                ax.set_ylim([idx_min - pc*didx, idx_max + pc*didx])

    for recorder, num in fignums.items():
        fig = plt.figure(num)
        if title is not None:
            fig.suptitle(title)
        if label is not None:
            # user-supplied `fignum` keys may not correspond to any plot
            fig.legend(lines.get(recorder, []), labels.get(recorder, []))

    if show:
        plt.show()

    return lines


def _normalize_data(data, normalize, index):
    '''
    Return `data` divided as requested by the `normalize` argument of
    :func:`plot_activity`, without modifying `data` in place.

    Parameters
    ----------
    data : array-like
        Recorded values.
    normalize : float, list, :class:`~nngt.NeuralPop` or None
        Normalization factor(s); None leaves `data` unchanged.
    index : int
        Index of the current meter, used when `normalize` is a population
        (the size of item `index` is used) or a list (item `index` is used).
    '''
    data = np.asarray(data)

    # NeuralPop is checked first since it is also a (non-string) container
    if isinstance(normalize, nngt.NeuralPop):
        return data / normalize[index].size
    if nonstring_container(normalize):
        return data / normalize[index]
    if normalize is not None:
        return data / normalize
    return data
def raster_plot(times, senders, limits=None, title="Spike raster",
                histogram=False, num_bins=1000, color="b", decimate=None,
                axis=None, fignum=None, label=None, show=True, sort=None,
                sort_attribute=None, network=None, transparent=True,
                kernel_center=0., kernel_std=30., resolution=None,
                cut_gaussian=5., **kwargs):
    """
    Plotting routine that constructs a raster plot along with
    an optional histogram.

    .. versionchanged:: 1.2
        Switched `hist` to `histogram`.

    .. versionchanged:: 1.0.1
        Added `axis` parameter.

    Parameters
    ----------
    times : list or :class:`numpy.ndarray`
        Spike times.
    senders : list or :class:`numpy.ndarray`
        Index for the spiking neuron for each time in `times`.
    limits : tuple, optional (default: None)
        Time limits of the plot (if not specified, times of first and last
        spike).
    title : string, optional (default: 'Spike raster')
        Title of the raster plot.
    histogram : bool, optional (default: False)
        Whether to plot the raster's histogram (a firing rate smoothed with a
        Gaussian kernel, see `kernel_std`).
    num_bins : int, optional (default: 1000)
        Currently unused.
    color : string or float, optional (default: 'b')
        Color of the plot lines and markers.
    decimate : int, optional (default: None)
        Represent only a fraction of the spiking neurons: only senders whose
        index is a multiple of `decimate` are represented (e.g. setting
        `decimate` to 10 will lead to only 10% of the neurons being
        represented).
    axis : matplotlib axis object, optional (default: new one)
        Axis that should be used to plot the activity.
    fignum : int, optional (default: None)
        Id of another raster plot to which the new data should be added.
    label : str, optional (default: None)
        Label the current data.
    show : bool, optional (default: True)
        Whether to show the plot right away or to wait for the next plt.show().
    sort : str, optional (default: None)
        Name of the sorting used to compute `senders`; if given together with
        `sort_attribute`, a second y-axis displays the attribute values.
    sort_attribute : array-like, optional (default: None)
        Values of the sorting attribute.
    network : :class:`~nngt.Network`, optional (default: None)
        If provided (and `histogram` is False), each group of the population
        is plotted with its own color and marker.
    transparent : bool, optional (default: True)
        Whether to hide the figure background patch.
    kernel_center : float, optional (default: 0.)
        Temporal shift of the Gaussian kernel, in ms.
    kernel_std : float, optional (default: 30.)
        Characteristic width of the Gaussian kernel (standard deviation) in ms.
    resolution : float or array, optional (default: None)
        The resolution at which the firing rate values will be computed; if
        None, the default of :func:`~nngt.analysis.total_firing_rate` is used.
        Choosing a value smaller than `kernel_std` is strongly advised.
        If resolution is an array, it will be considered as the times where
        the firing rate should be computed.
    cut_gaussian : float, optional (default: 5.)
        Range over which the Gaussian will be computed (for the histogram).
        By default, we consider the 5-sigma range. Decreasing this value will
        increase speed at the cost of lower fidelity; increasing it will
        increase the fidelity at the cost of speed.
    **kwargs
        Passed to matplotlib's ``plot`` for the raster markers, except
        ``hist_ax``: an existing axis on which the histogram is drawn (the
        raster then goes on `axis`).

    Returns
    -------
    lines : list of :class:`matplotlib.lines.Line2D`
        Lines containing the data that was plotted (empty if there are no
        spikes).
    """
    import matplotlib.pyplot as plt

    lines = []

    # `hist_ax` is consumed here and must not reach matplotlib
    mpl_kwargs = {k: v for k, v in kwargs.items() if k != 'hist_ax'}

    if label is not None:
        mpl_kwargs['label'] = label

    # accept plain lists: the boolean/fancy indexing below needs arrays
    times = np.asarray(times)
    senders = np.asarray(senders)

    # decimate if necessary
    if decimate is not None:
        idx_keep = np.where(np.mod(senders, decimate) == 0)[0]
        senders = senders[idx_keep]
        times = times[idx_keep]

    if not len(times):
        _log_message(logger, "WARNING",
                     "No activity was detected during the simulation.")
        return lines

    fig = axis.get_figure() if axis is not None else plt.figure(fignum)

    if transparent:
        fig.patch.set_visible(False)

    ylabel = "Neuron ID"
    xlabel = "Time (ms)"

    # 1% time margin; assumes `times` is sorted
    delta_t = 0.01*(times[-1]-times[0])

    if histogram:
        if kwargs.get("hist_ax", None) is None:
            # stack two new subplots (raster + histogram) below existing ones
            # NOTE(review): `change_geometry` was removed in matplotlib 3.6
            num_axes = len(fig.axes)
            for i, old_ax in enumerate(fig.axes):
                old_ax.change_geometry(num_axes + 2, 1, i+1)
            ax1 = fig.add_subplot(num_axes + 2, 1, num_axes + 1)
            ax2 = fig.add_subplot(num_axes + 2, 1, num_axes + 2, sharex=ax1)
        else:
            ax1 = axis
            ax2 = kwargs["hist_ax"]

        if limits is not None:
            start, stop = limits
            keep = (times >= start) & (times <= stop)
            times = times[keep]
            senders = senders[keep]

        lines.extend(ax1.plot(
            times, senders, c=color, marker="o", linestyle='None',
            mec="k", mew=0.5, ms=4, **mpl_kwargs))

        ax1_lines = ax1.lines
        if len(ax1_lines) > 1:
            t_max = max(ax1_lines[0].get_xdata().max(), times[-1])
            ax1.set_xlim([-delta_t, t_max+delta_t])

        ax1.set_ylabel(ylabel)

        if limits is not None:
            ax1.set_xlim(*limits)

        fr, fr_times = total_firing_rate(
            data=np.array([senders, times]).T, kernel_center=kernel_center,
            kernel_std=kernel_std, resolution=resolution,
            cut_gaussian=cut_gaussian)

        hist_lines = ax2.get_lines()

        if hist_lines:
            # stack the new rate on top of the previous curve ("bottom"),
            # realigned on the new time samples
            data = hist_lines[-1].get_data()
            bottom = data[1]
            if limits is None:
                dt = fr_times[1] - fr_times[0]
                old_times = data[0]
                old_start = int(old_times[0] / dt)
                new_start = int(fr_times[0] / dt)
                old_end = int(old_times[-1] / dt)
                new_end = int(fr_times[-1] / dt)
                diff_start = new_start-old_start
                diff_end = new_end-old_end
                if diff_start > 0:
                    bottom = bottom[diff_start:]
                else:
                    bottom = np.concatenate((np.zeros(-diff_start), bottom))
                if diff_end > 0:
                    bottom = np.concatenate((bottom, np.zeros(diff_end)))
                else:
                    bottom = bottom[:diff_end-1]
                # final length fix so that `fr + bottom` broadcasts
                b_len, h_len = len(bottom), len(fr)
                if b_len > h_len:
                    bottom = bottom[:h_len]
                elif b_len < h_len:
                    bottom = np.concatenate((bottom, np.zeros(h_len-b_len)))
            else:
                # NOTE(review): assumes the previous curve has one sample
                # more than `fr` when `limits` is set — confirm
                bottom = bottom[:-1]
            ax2.fill_between(fr_times, fr + bottom, bottom, color=color)
            lines.extend(ax2.plot(fr_times, fr + bottom, ls="", marker=""))
        else:
            ax2.fill_between(fr_times, fr, 0., color=color)
            lines.extend(ax2.plot(fr_times, fr, ls="", marker=""))

        ax2.set_ylabel("Rate (Hz)")
        ax2.set_xlabel(xlabel)
        ax2.set_xlim(ax1.get_xlim())

        _second_axis(sort, sort_attribute, ax1)
    else:
        if axis is not None:
            ax = axis
        else:
            # NOTE(review): `change_geometry` was removed in matplotlib 3.6
            num_axes = len(fig.axes)
            for i, old_ax in enumerate(fig.axes):
                old_ax.change_geometry(num_axes + 1, 1, i+1)
            ax = fig.add_subplot(num_axes + 1, 1, num_axes + 1)

        if limits is not None:
            start, stop = limits
            keep = (times >= start) & (times <= stop)
            times = times[keep]
            senders = senders[keep]

        if network is not None:
            pop = network.population
            colors = palette_discrete(np.linspace(0, 1, len(pop)))
            mm = itertools.cycle(markers)
            for m, (k, v), c in zip(mm, pop.items(), colors):
                # NOTE(review): `senders` are matched against NEST gids, but
                # plot_activity passes sorted positions — confirm intent
                keep = np.where(
                    np.isin(senders, network.nest_gids[v.ids]))[0]
                if len(keep):
                    if label is None:
                        mpl_kwargs['label'] = k
                    lines.extend(ax.plot(
                        times[keep], senders[keep], c=c, marker=m,
                        ls='None', mec='k', mew=0.5, ms=4, **mpl_kwargs))
        else:
            lines.extend(ax.plot(
                times, senders, c=color, marker="o", linestyle='None',
                mec="k", mew=0.5, ms=4, **mpl_kwargs))

        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

        if limits is not None:
            ax.set_xlim(limits)
        else:
            _set_ax_lims(ax, np.max(times), np.min(times), np.max(senders),
                         np.min(senders))

        if label is not None:
            ax.legend(bbox_to_anchor=(1.1, 1.2))

        _second_axis(sort, sort_attribute, ax)

    fig.suptitle(title)

    if show:
        plt.show()

    return lines
#-----------------------------------------------------------------------------
# Tools
#-----------------------------------------------------------------------------
def _fill_between_steps(x, y1, y2=0, h_align='mid'):
'''
Fills a hole in matplotlib: fill_between for step plots.
Parameters :
------------
x : array-like
Array/vector of index values. These are assumed to be equally-spaced.
If not, the result will probably look weird...
y1 : array-like
Array/vector of values to be filled under.
y2 : array-Like
Array/vector or bottom values for filled area. Default is 0.
'''
# First, duplicate the x values
xx = np.repeat(x,2)
# Now: the average x binwidth
xstep = np.repeat((x[1:] - x[:-1]), 2)
xstep = np.concatenate(([xstep[0]], xstep, [xstep[-1]]))
# Now: add one step at end of row.
#~ xx = np.append(xx, xx.max() + xstep[-1])
# Make it possible to change step alignment.
if h_align == 'mid':
xx -= xstep / 2.
elif h_align == 'right':
xx -= xstep
# Also, duplicate each y coordinate in both arrays
y1 = np.repeat(y1,2)#[:-1]
if type(y2) == np.ndarray:
y2 = np.repeat(y2,2)#[:-1]
return xx, y1, y2
def _moving_average (values, window):
weights = np.repeat(1.0, window)/window
sma = np.convolve(values, weights, 'same')
return sma
def _second_axis(sort, sort_attribute, ax):
import matplotlib.pyplot as plt
if sort is not None:
fig = ax.get_figure()
twin = None
for axis in fig.axes:
if axis.get_ylabel() == sort:
twin = axis
break
if twin is None:
asort = np.argsort(sort_attribute)
twin = ax.twinx()
twin.grid(False)
twin.set_ylabel(sort)
plt.draw()
old_ticks = ax.get_yticks()
twin.set_yticks(old_ticks)
twin.set_ylim(ax.get_ylim())
labels = ['' for _ in range(len(old_ticks))]
idx_max = len(sort_attribute) - 1
for i, t in enumerate(old_ticks):
if t >= 0:
idx = min(int(t), idx_max)
labels[i] = _sci_format(sort_attribute[asort[idx]])
twin.set_yticklabels(labels)
def _sci_format(n):
label = ''
if np.abs(n) < 0.01 or np.abs(n) >= 1000:
a = '{:.1E}'.format(n)
label = '$' + a.split('E')[0].rstrip('0').rstrip('.') + '\\cdot 10^{'
exponent = a.split('E')[1].lstrip('0')
if exponent[0] == '-':
exponent = exponent[0] + exponent[1:].lstrip('0')
elif exponent[0] == '+':
exponent = exponent[1:].lstrip('0')
label += exponent + '}$'
elif np.abs(n) >= 100:
label = '{:.0f}'.format(n)
elif np.abs(n) >= 10:
label = '{:.1f}'.format(n)
else:
label = '{:.2f}'.format(n)
return label
def _make_patch_spines_invisible(ax):
    '''
    Keep the frame of `ax` on while hiding its background patch and every
    one of its spines (the caller can then re-enable a single spine).
    '''
    ax.set_frame_on(True)
    ax.patch.set_visible(False)
    for spine_name in ax.spines:
        ax.spines[spine_name].set_visible(False)