#!/usr/bin/env python
#-*- coding:utf-8 -*-
#
# This file is part of the NNGT project to generate and analyze
# neuronal networks and their activity.
# Copyright (C) 2015-2019 Tanguy Fardet
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
""" Animation tools """
import warnings
import weakref
import subprocess
import numpy as np
import matplotlib as mpl
from matplotlib.lines import Line2D
import matplotlib.animation as anim
from nngt.lib import InvalidArgument
from nngt.lib.sorting import _sort_neurons
from nngt.analysis import total_firing_rate
from .plt_networks import draw_network
# ----------------- #
# Animation classes #
# ----------------- #
class _SpikeAnimator(anim.TimedAnimation):

    '''
    Generic class to plot raster plot and firing-rate in time for a given
    network.

    .. warning::
        This class is not supposed to be instantiated directly, but only
        through Animation2d or AnimationNetwork.
    '''

    # candidate spacings between two consecutive x-ticks; `_make_ticks`
    # selects the entry closest to its target spacing
    steps = [
        1, 5, 10, 20, 25, 50, 100, 200, 250, 500,
        1000, 2000, 2500, 5000, 10000, 25000, 50000, 75000, 100000, 250000
    ]

    def __init__(self, source, sort_neurons=None,
                 network=None, grid=(2, 4), pos_raster=(0, 2),
                 span_raster=(1, 2), pos_rate=(1, 2),
                 span_rate=(1, 2), make_rate=True, **kwargs):
        '''
        Load the recorded spikes, then create the figure together with the
        raster and rate axes.

        Parameters
        ----------
        source : NEST gid tuple or str
            NEST gid of the `spike_detector`(s) which recorded the network or
            path to a file containing the recorded spikes.
        sort_neurons : str or list, optional (default: None)
            Sorting criterion forwarded to `_sort_neurons`; ignored (with a
            warning) if no `network` is provided.
        network : optional (default: None)
            Network which was recorded. If None, the number of neurons is
            estimated from the range of the sender ids.
        grid : 2-tuple, optional (default: (2, 4))
            Shape of the subplot grid.
        pos_raster, span_raster : 2-tuples, optional
            Position and (rowspan, colspan) of the raster axis in `grid`.
        pos_rate, span_rate : 2-tuples, optional
            Position and (rowspan, colspan) of the second axis in `grid`.
        make_rate : bool, optional (default: True)
            Whether the total firing rate should be computed and plotted.
        **kwargs : dict, optional
            "figsize" and "dpi" are used to create the figure.

        Raises
        ------
        RuntimeError
            If no spike was recorded after the first time step.

        Note
        ----
        Calling class is supposed to have defined `self.times`, `self.start`,
        `self.duration`, `self.trace`, and `self.timewindow`.
        '''
        import matplotlib.pyplot as plt
        from nngt.simulation.nest_activity import _get_data

        # organization
        self.grid = grid
        self.has_rate = make_rate

        # get data
        data_s = _get_data(source)
        spikes = np.where(data_s[:, 1] >= self.times[0])[0]

        # test the number of matches, not their values: a single match
        # located at index 0 is a valid (non-empty) result
        if spikes.size == 0:
            raise RuntimeError("No spikes between {} and {}.".format(
                self.start, self.times[-1]))

        idx_start = spikes[0]
        self.spikes = data_s[:, 1][idx_start:]
        self.senders = data_s[:, 0][idx_start:].astype(int)
        self._ymax = np.max(self.senders)
        self._ymin = np.min(self.senders)

        if network is None:
            # ids in [ymin, ymax] (inclusive) -> ymax - ymin + 1 neurons
            self.num_neurons = int(self._ymax - self._ymin) + 1
        else:
            self.num_neurons = network.node_nb()

        # sorting
        if sort_neurons is not None:
            if network is not None:
                sorted_neurons = _sort_neurons(
                    sort_neurons, self.senders, network, data=data_s)
                self.senders = sorted_neurons[self.senders]
            else:
                warnings.warn("Could not sort neurons because no " \
                              + "`network` was provided.")

        self.dt = self.times[1] - self.times[0]
        self.simtime = self.times[-1] - self.times[0]

        # generate the spike-rate
        if make_rate:
            self.firing_rate, _ = total_firing_rate(
                network, data=data_s, resolution=self.times)

        # figure/canvas: pause/resume and step by step interactions
        self.fig = plt.figure(
            figsize=kwargs.get("figsize", (8, 6)), dpi=kwargs.get("dpi", 75))
        self.pause = False
        self.pause_after = False
        self.event = None
        self.increment = 1
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('key_press_event', self.on_keyboard_press)
        self.fig.canvas.mpl_connect(
            'key_release_event', self.on_keyboard_release)

        # Axes for spikes and spike-rate/other representations
        self.spks = plt.subplot2grid(
            grid, pos_raster, rowspan=span_raster[0], colspan=span_raster[1])
        self.second = plt.subplot2grid(
            grid, pos_rate, rowspan=span_rate[0], colspan=span_rate[1],
            sharex=self.spks)

        # lines: `_` = full history, `_a` = recent trace (red),
        # `_e` = current point
        self.line_spks_ = Line2D(
            [], [], ls='None', marker='o', color='black', ms=2, mew=0)
        self.line_spks_a = Line2D(
            [], [], ls='None', marker='o', color='red', ms=2, mew=0)
        self.line_second_ = Line2D([], [], color='black')
        self.line_second_a = Line2D([], [], color='red', linewidth=2)
        self.line_second_e = Line2D(
            [], [], color='red', marker='o', markeredgecolor='r')

        # Spikes raster plot
        kw_args = {}
        if self.timewindow != self.duration:
            kw_args['xlim'] = (self.start,
                min(self.simtime, self.timewindow + self.start))
        ylim = (self._ymin, self._ymax)
        self.lines_raster = [self.line_spks_, self.line_spks_a]
        self.set_axis(self.spks, xlabel='Time (ms)', ylabel='Neuron',
            lines=self.lines_raster, ylim=ylim, set_xticks=True, **kw_args)
        self.lines_second = [
            self.line_second_, self.line_second_a, self.line_second_e]

        # Rate plot
        if make_rate:
            self.set_axis(
                self.second, xlabel='Time (ms)', ylabel='Rate (Hz)',
                lines=self.lines_second, ydata=self.firing_rate, **kw_args)

    #-------------------------------------------------------------------------
    # Axis definition

    def set_axis(self, axis, xlabel, ylabel, lines, xdata=None, ydata=None,
                 **kwargs):
        '''
        Setup an axis.

        Parameters
        ----------
        axis : :class:`matplotlib.axes.Axes` object
        xlabel : str
        ylabel : str
        lines : list of :class:`matplotlib.lines.Line2D` objects
        xdata : 1D array-like, optional (default: None)
            Currently unused.
        ydata : 1D array-like, optional (default: None)
            Used to compute the y-limits when "ylim" is not provided.
        **kwargs : dict, optional (default: {})
            Optional arguments ("xlim" or "ylim", 2-tuples; "set_xticks",
            bool).
        '''
        axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        if kwargs.get('set_xticks', False):
            self._make_ticks(self.timewindow)
        for line2d in lines:
            axis.add_line(line2d)
        if 'xlim' in kwargs:
            axis.set_xlim(*kwargs['xlim'])
        else:
            # relies on `_make_ticks` having been called beforehand
            xmin, xmax = self.xticks[0], self.xticks[-1]
            axis.set_xlim(_min_axis(xmin, xmax), _max_axis(xmax, xmin))
        if 'ylim' in kwargs:
            axis.set_ylim(*kwargs['ylim'])
        else:
            ymin, ymax = np.min(ydata), np.max(ydata)
            axis.set_ylim(_min_axis(ymin, ymax), _max_axis(ymax, ymin))

    def _draw(self, i, head, head_slice, spike_cum, spike_slice):
        '''
        Update the raster and rate lines for frame `i` and scroll the time
        axis if the current time left the visible window.

        Parameters
        ----------
        i : int
            Index of the current frame in `self.times`.
        head : int
            Index of the point marked as current on the rate plot.
        head_slice : boolean mask over `self.times`
            Time steps belonging to the recent trace.
        spike_cum : boolean mask over `self.spikes`
            Spikes that already occurred.
        spike_slice : boolean mask over `self.spikes`
            Spikes belonging to the recent trace.
        '''
        self.line_spks_.set_data(
            self.spikes[spike_cum], self.senders[spike_cum])

        if np.any(spike_slice):
            self.line_spks_a.set_data(
                self.spikes[spike_slice], self.senders[spike_slice])
        else:
            self.line_spks_a.set_data([], [])

        if self.has_rate:
            self.line_second_.set_data(self.times[:i], self.firing_rate[:i])
            self.line_second_a.set_data(
                self.times[head_slice], self.firing_rate[head_slice])
            # `set_data` requires sequences in recent matplotlib versions
            self.line_second_e.set_data(
                [self.times[head]], [self.firing_rate[head]])

        # set axis limits: 1. check whether the user changed the window
        xmin, xmax = self.spks.get_xlim()
        current_window = xmax - xmin
        default_window = (
            np.isclose(current_window, self.timewindow)
            or np.isclose(current_window, self.simtime - self.start))

        # 2. change if necessary (`second` shares its x-axis with `spks`)
        if default_window:
            now = self.times[i]
            if now >= xmax:
                self.spks.set_xlim(now - self.timewindow, now)
                self.second.set_xlim(now - self.timewindow, now)
            elif now <= xmin:
                self.spks.set_xlim(self.start, self.timewindow + self.start)

    def _make_ticks(self, timewindow):
        '''
        Compute `self.xticks` and `self.xlabels` so that about five ticks
        are visible inside a window of length `timewindow`.
        '''
        target_num_ticks = np.ceil(self.duration / timewindow * 5)
        target_step = self.duration / target_num_ticks
        # explicit array: `list - float` is not defined in plain Python
        idx_step = np.abs(np.asarray(self.steps) - target_step).argmin()
        step = self.steps[idx_step]
        num_steps = int(self.duration / step) + 2
        self.xticks = [self.start + i*step for i in range(num_steps)]
        self.xlabels = [str(i) for i in self.xticks]

    #-------------------------------------------------------------------------
    # User interaction

    def _toggle_pause(self):
        ''' Pause the animation if it is running, resume it otherwise. '''
        if self.pause:
            self.pause = False
            self.event_source.start()
        else:
            self.pause = True
            self.event_source.stop()

    def on_click(self, event):
        ''' Pause/resume the animation on a middle-button click. '''
        # `event.button` is an integer (or IntEnum), never a string
        if event.button == 2:
            self._toggle_pause()

    def on_keyboard_press(self, kb_event):
        '''
        Keyboard control: space toggles pause; 'N'/'P' step forward or
        backward (handled by the frame generators); 'F'/'B' double or halve
        the frame increment.
        '''
        key = kb_event.key
        if key == ' ':
            self._toggle_pause()
        else:
            if key in ('B', 'F', 'N', 'P') and self.pause:
                self.pause = False
                self.pause_after = True    # stop at next iteration
                self.event_source.start()  # restart temporarily
            if key == 'F':
                self.increment *= 2
            elif key == 'B':
                self.increment = max(1, int(self.increment / 2))
        self.event = kb_event

    def on_keyboard_release(self, kb_event):
        ''' Pause again after a single step and forget the pressed key. '''
        if kb_event.key in (' ', 'B', 'F', 'N', 'P'):
            if self.pause_after:
                self.pause = True
                self.pause_after = False
                self.event_source.stop()  # pause again
        self.event = None

    def save_movie(self, filename, fps=30, video_encoder='html5', codec="h264",
                   bitrate=-1, start=None, stop=None, interval=None,
                   num_frames=None, metadata=None):
        '''
        Save the animation to a movie file.

        Parameters
        ----------
        filename : :obj:`str`
            Name of the file where the movie will be saved.
        fps : int, optional (default: 30)
            Frame per second.
        video_encoder : :obj:`str`, optional (default 'html5')
            Movie encoding format; either 'ffmpeg', 'html5', or 'imagemagick'.
        codec : :obj:`str`, optional (default: "h264")
            Codec to use for writing movie; if None, default `animation.codec`
            from `matplotlib` will be used.
        bitrate : int, optional (default: -1)
            Controls size/quality tradeoff for movie. Default (-1) lets utility
            auto-determine.
        start : float, optional (default: initial time)
            Start time, corresponding to the first spike time that will appear
            on the video.
        stop : float, optional (default: final time)
            Stop time, corresponding to the last spike time that will appear
            on the video.
        interval : int, optional (default: None)
            Timestep increment for each new frame. Default saves all
            timesteps (often heavy). E.g. setting `interval` to 10 will make
            the file 10 times lighter.
        num_frames : int, optional (default: None)
            Total number of frames that should be saved.
        metadata : :obj:`dict`, optional (default: None)
            Metadata for the video (e.g. 'title', 'artist', 'comment',
            'copyright')

        Raises
        ------
        InvalidArgument
            If both `interval` and `num_frames` are provided.

        Notes
        -----
        * ``ffmpeg`` is required for 'ffmpeg' and 'html5' encoders.
          To get available formats, type ``ffmpeg -formats`` in a terminal;
          type ``ffmpeg -codecs | grep EV`` for available codecs.
        * Imagemagick is required for 'imagemagick' encoder.
        '''
        if interval is not None and num_frames is not None:
            raise InvalidArgument("Incompatible arguments `interval` and "
                                  "`num_frames` provided. Choose one.")
        elif interval is None and num_frames is None:
            self.increment = 1
            self.save_count = self.num_frames
        elif interval is None:
            self.increment = max(1, int(self.num_frames / num_frames))
            self.save_count = num_frames
        else:
            self.increment = interval
            self.save_count = int(self.num_frames / interval)

        start_frame = 0
        stop_frame = self.save_count

        # frame k displays `self.times[k*increment]`, so requested times must
        # be taken relative to the first time step
        t0 = self.times[0]

        if start is not None:
            start_frame = int((start - t0) / self.dt / self.increment) + 1
            self.spks.set_xlim(left=start)
            self.second.set_xlim(left=start)
        if stop is not None:
            stop_frame = int((stop - t0) / self.dt / self.increment) + 1
            self.spks.set_xlim(right=stop)
            self.second.set_xlim(right=stop)

        # never request a frame located outside of the available time steps
        last_frame = (len(self.times) - 1) // self.increment + 1
        start_frame = max(0, start_frame)
        stop_frame = min(stop_frame, last_frame)

        _save_movie(
            self, filename, fps, video_encoder, codec, bitrate, metadata,
            self.fig.dpi, start_frame, stop_frame)
class Animation2d(_SpikeAnimator, anim.FuncAnimation):

    '''
    Class to plot the raster plot, firing-rate, and average trajectory in
    a 2D phase-space for a network activity.
    '''

    def __init__(self, source, multimeter, start=0., timewindow=None,
                 trace=5., x='time', y='V_m', sort_neurons=None,
                 network=None, interval=50, vector_field=False, **kwargs):
        r'''
        Generate a SubplotAnimation instance to plot a network activity.

        Parameters
        ----------
        source : tuple
            NEST gid of the ``spike_detector``(s) which recorded the network.
        multimeter : tuple
            NEST gid of the ``multimeter``(s) which recorded the network.
        start : double, optional (default: 0.)
            Time from which the recorded data is displayed.
        timewindow : double, optional (default: None)
            Time window which will be shown for the spikes and self.second.
        trace : double, optional (default: 5.)
            Interval of time (ms) over which the data is overlayed in red.
        x : str, optional (default: "time")
            Name of the `x`-axis variable (must be either "time" or the name
            of a NEST recordable in the `multimeter`).
        y : str, optional (default: "V_m")
            Name of the `y`-axis variable (must be either "time" or the name
            of a NEST recordable in the `multimeter`).
        vector_field : bool, optional (default: False)
            Whether the :math:`\dot{x}` and :math:`\dot{y}` arrows should be
            added to phase space. Requires additional 'dotx' and 'doty'
            arguments which are user defined functions to compute the
            derivatives of `x` and `y` in time. These functions take 3
            parameters, which are `x`, `y`, and `time_dependent`, where the
            last parameter is a list of doubles associated to recordables
            from the neuron model (see example for details). These recordables
            must be declared in a `time_dependent` parameter.
        sort_neurons : str or list, optional (default: None)
            Sort neurons using a topological property ("in-degree",
            "out-degree", "total-degree" or "betweenness"), an activity-related
            property ("firing_rate", 'B2') or a user-defined list of sorted
            neuron ids. Sorting is performed by increasing value of the
            `sort_neurons` property from bottom to top inside each group.
        network : optional (default: None)
            Network which was recorded.
        interval : int, optional (default: 50)
            Forwarded to :class:`matplotlib.animation.FuncAnimation`.
        **kwargs : dict, optional (default: {})
            Optional arguments such as 'make_rate', 'num_xarrows',
            'num_yarrows', 'dotx', 'doty', 'time_dependent', 'recordables',
            'arrow_scale'.
        '''
        import matplotlib.pyplot as plt
        import nest

        x = "times" if x == "time" else x
        y = "times" if y == "time" else y

        # get data
        data_mm = nest.GetStatus(multimeter)[0]["events"]
        all_times = data_mm["times"]
        idx_start = np.where(all_times >= start)[0][0]
        self.idx_start = idx_start
        self.times = all_times[idx_start:]
        # count frames after discarding the steps before `start`, since all
        # arrays indexed by the frame number are sliced the same way
        self.num_frames = len(self.times)

        self.simtime = self.times[-1]
        self.start = start
        self.duration = self.simtime - start
        self.trace = trace
        self.vector_field = vector_field

        if timewindow is None:
            self.timewindow = self.duration
        else:
            self.timewindow = min(timewindow, self.duration)

        # init _SpikeAnimator parent class (create figure and right axes)
        kwargs.setdefault('make_rate', True)

        super(Animation2d, self).__init__(
            source, sort_neurons=sort_neurons, network=network,
            **kwargs)

        # Data and axis for phase-space (values divided by neuron number)
        self.x = data_mm[x][idx_start:] / self.num_neurons
        self.y = data_mm[y][idx_start:] / self.num_neurons

        self.ps = plt.subplot2grid((2, 4), (0, 0), rowspan=2, colspan=2)
        self.ps.grid(False)

        # lines
        self.line_ps_ = Line2D([], [], color='black')
        self.line_ps_a = Line2D([], [], color='red', linewidth=2)
        self.line_ps_e = Line2D(
            [], [], color='red', marker='o', markeredgecolor='r')
        lines = [self.line_ps_, self.line_ps_a, self.line_ps_e]

        xlim = (_min_axis(self.x.min()), _max_axis(self.x.max()))
        self.set_axis(
            self.ps, xlabel=_convert_axis(x), ylabel=_convert_axis(y),
            lines=lines, xdata=self.x, ydata=self.y, xlim=xlim)

        # For quiver plot (vector field)
        nx = kwargs.get('num_xarrows', 20)
        ny = kwargs.get('num_yarrows', 20)
        scale = kwargs.get('arrow_scale', 30.)

        if self.vector_field:
            time_dependent_rec = kwargs.get('time_dependent', [])
            self.time_dependent = [
                data_mm[key][idx_start:] / self.num_neurons
                for key in time_dependent_rec
            ]
            self.dotx, self.doty = kwargs['dotx'], kwargs['doty']
            xx = np.repeat(np.linspace(xlim[0], xlim[1], nx), ny)
            yy = np.tile(np.linspace(self.y.min(), self.y.max(), ny), nx)
            # NOTE(review): empty U/V may be rejected by recent matplotlib
            # versions -- confirm against the supported versions
            self.q = self.ps.quiver(xx, yy, [], [], scale=scale, color='grey')

        plt.tight_layout()

        anim.FuncAnimation.__init__(self, self.fig, self._draw, self._gen_data,
                                    interval=interval, blit=True)

    #-------------------------------------------------------------------------
    # Animation instructions

    def _gen_data(self):
        '''
        Generate the successive frame indices, taking pause and the 'N'/'P'
        (next/previous) keys into account.
        '''
        i = -1
        imax = len(self.x) - 1
        while i < imax - self.increment:
            if not self.pause:
                if self.event is None:
                    i += self.increment
                else:
                    if self.event.key == 'N':
                        i += self.increment
                    elif self.event.key == 'P':
                        # do not step before the first frame (a negative
                        # index would wrap around to the end of the data)
                        i = max(0, i - self.increment)
                    self.event = None
            yield i

    def _draw(self, framedata):
        '''
        Draw the frame of index `framedata` and return the updated artists.
        '''
        i = int(framedata)
        now = self.times[i]
        head = i - 1

        head_slice = ((self.times > (now - self.trace))
                      & (self.times < now))
        spike_slice = ((self.spikes > (now - self.trace))
                       & (self.spikes <= now))
        spike_cum = self.spikes < now

        lines = []

        if self.vector_field:
            time_dep = [arr[i] for arr in self.time_dependent]
            u = self.dotx(self.q.X, self.q.Y, time_dep)
            v = self.doty(self.q.X, self.q.Y, time_dep)
            self.q.set_UVC(u, v)
            lines.append(self.q)

        self.line_ps_.set_data(self.x[:i], self.y[:i])
        self.line_ps_a.set_data(self.x[head_slice], self.y[head_slice])
        # `set_data` requires sequences in recent matplotlib versions
        self.line_ps_e.set_data([self.x[i]], [self.y[i]])

        lines.extend([self.line_ps_, self.line_ps_a, self.line_ps_e,
                      self.line_spks_, self.line_spks_a, self.line_second_,
                      self.line_second_a, self.line_second_e])

        super(Animation2d, self)._draw(
            i, head, head_slice, spike_cum, spike_slice)

        return lines

    def _init_draw(self):
        '''
        Remove ticks from spks/second axes, save background,
        then restore state to allow for moveable axes and labels.
        '''
        xlim = self.spks.get_xlim()
        xlabel = self.spks.get_xlabel()
        time_axes = (self.spks, self.second)

        # remove
        for axis in time_axes:
            axis.set_xticks([])
            axis.set_xticklabels([])
            axis.set_xlabel("")

        # background
        self.fig.canvas.draw()
        self.bg = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        # restore
        for axis in time_axes:
            axis.set_xticks(self.xticks)
            axis.set_xticklabels(self.xlabels)
            axis.set_xlim(*xlim)
            axis.set_xlabel(xlabel)

        if self.vector_field:
            self.q.set_UVC([], [])

        # initialize empty lines
        lines = [self.line_ps_, self.line_ps_a, self.line_ps_e,
                 self.line_spks_, self.line_spks_a,
                 self.line_second_, self.line_second_a, self.line_second_e]
        for l in lines:
            l.set_data([], [])
class AnimationNetwork(_SpikeAnimator, anim.FuncAnimation):

    '''
    Class to plot the raster plot, firing-rate, and space-embedded spiking
    activity (neurons on the graph representation flash when spiking) in time.
    '''

    def __init__(self, source, network, resolution=1., start=0.,
                 timewindow=None, trace=5., show_spikes=False,
                 sort_neurons=None, decimate_connections=False,
                 interval=50, repeat=True, resting_size=None, active_size=None,
                 **kwargs):
        '''
        Generate a SubplotAnimation instance to plot a network activity.

        Parameters
        ----------
        source : tuple
            NEST gid of the ``spike_detector``(s) which recorded the network.
        network : :class:`~nngt.SpatialNetwork`
            Network embedded in space to plot the activity of the neurons in
            space.
        resolution : double, optional (default: 1.)
            Time resolution of the animation.
        start : double, optional (default: 0.)
            Time at which the animation starts.
        timewindow : double, optional (default: None)
            Time window which will be shown for the spikes and self.second.
        trace : double, optional (default: 5.)
            Interval of time (ms) over which the data is overlayed in red.
        show_spikes : bool, optional (default: False)
            Whether a spike trajectory should be displayed on the network.
        sort_neurons : str or list, optional (default: None)
            Sort neurons using a topological property ("in-degree",
            "out-degree", "total-degree" or "betweenness"), an activity-related
            property ("firing_rate", 'B2') or a user-defined list of sorted
            neuron ids. Sorting is performed by increasing value of the
            `sort_neurons` property from bottom to top inside each group.
        decimate_connections : optional (default: False)
            Currently unused.
        interval : int, optional (default: 50)
            Forwarded to :class:`matplotlib.animation.FuncAnimation`.
        repeat : bool, optional (default: True)
            Forwarded to :class:`matplotlib.animation.FuncAnimation`.
        resting_size : float, optional (default: None)
            Marker size of the neurons; computed from the size of the axis
            and the number of neurons if None.
        active_size : float, optional (default: None)
            Marker size of the spiking neurons; resting size plus 2 if None.
        **kwargs : dict, optional (default: {})
            Optional arguments such as 'make_rate', or all arguments for the
            :func:`nngt.plot.draw_network`.
        '''
        import matplotlib.pyplot as plt
        from nngt.simulation.nest_activity import _get_data

        # weak reference: the animation does not keep the network alive
        self.network = weakref.ref(network)
        self.simtime = _get_data(source)[-1, 1]
        self.times = np.arange(start, self.simtime + resolution, resolution)

        self.num_frames = len(self.times)
        self.start = start
        self.duration = self.simtime - start
        self.trace = trace
        self.show_spikes = show_spikes

        if timewindow is None:
            self.timewindow = self.duration
        else:
            self.timewindow = min(timewindow, self.duration)

        # init _SpikeAnimator parent class (create figure and right axes)
        self.kwargs = kwargs
        mpl.rcParams['agg.path.chunksize'] = kwargs.get('chunksize', 10000)
        kwargs.setdefault('make_rate', True)

        super(AnimationNetwork, self).__init__(
            source, sort_neurons=sort_neurons, network=network,
            **kwargs)

        self.env = plt.subplot2grid((2, 4), (0, 0), rowspan=2, colspan=2)

        # Data and axis for network representation
        bbox = self.env.get_window_extent().transformed(
            self.fig.dpi_scale_trans.inverted())
        area_px = bbox.width * bbox.height * self.fig.dpi**2

        # neuron size
        if resting_size is not None:
            n_size = resting_size
        else:
            n_size = max(2, 0.5*np.sqrt(area_px/self.num_neurons))
        if active_size is None:
            active_size = n_size + 2

        pos = network.get_positions()  # positions of the neurons
        self.x = pos[:, 0]
        self.y = pos[:, 1]

        # neurons
        self.line_neurons = Line2D(
            [], [], ls='None', marker='o', color='black', ms=n_size, mew=0)
        self.line_neurons_a = Line2D(
            [], [], ls='None', marker='o', color='red', ms=active_size, mew=0)
        self.lines_env = [self.line_neurons, self.line_neurons_a]

        xlim = (_min_axis(self.x.min()), _max_axis(self.x.max()))
        self.set_axis(self.env, xlabel='Network', ylabel='',
                      lines=self.lines_env, xdata=self.x, ydata=self.y,
                      xlim=xlim)

        # spike trajectory
        if show_spikes:
            self.line_st_a = Line2D([], [], color='red', linewidth=1)
            self.line_st_e = Line2D(
                [], [], color='red', marker='d', ms=2, markeredgecolor='r')
            self.lines_env.extend((self.line_st_a, self.line_st_e))

        # remove the axes and grid from env
        self.env.set_xticks([])
        self.env.set_yticks([])
        self.env.set_xticklabels([])
        self.env.set_yticklabels([])
        self.env.grid(None)

        plt.tight_layout()

        anim.FuncAnimation.__init__(
            self, self.fig, self._draw, self._gen_data, repeat=repeat,
            interval=interval, blit=True)

    #-------------------------------------------------------------------------
    # Animation instructions

    def _gen_data(self):
        '''
        Generate the successive frame indices, taking pause and the 'N'/'P'
        (next/previous) keys into account.
        '''
        i = -1
        imax = len(self.times) - 1
        while i < imax - self.increment:
            if not self.pause:
                if self.event is None:
                    i += self.increment
                elif self.event.key == 'N':
                    i += self.increment
                elif self.event.key == 'P':
                    # do not step before the first frame (a negative
                    # index would wrap around to the end of the data)
                    i = max(0, i - self.increment)
            yield i

    def _draw(self, framedata):
        '''
        Draw the frame of index `framedata` and return the updated artists.
        '''
        i = int(framedata)

        if i == 0:  # initialize neurons and connections
            self.line_neurons.set_data(self.x, self.y)

        now = self.times[i]
        head = i - 1
        head_slice = ((self.times > now - self.trace)
                      & (self.times < now))
        spike_slice = ((self.spikes > now - self.trace)
                       & (self.spikes <= now))
        spike_cum = self.spikes < now

        network = self.network()

        # neurons that spiked during the last `trace` ms
        pos_ids = network.id_from_nest_gid(self.senders[spike_slice])
        self.line_neurons_a.set_data(self.x[pos_ids], self.y[pos_ids])

        if self.show_spikes:
            # @todo: make this work for heterogeneous delays
            # @todo: unfinished feature, the values computed here are not
            # drawn yet
            delays = np.average(network.get_delays())
            departures = self.spikes[spike_slice]
            arrivals = departures + delays
            # get the spikers, repeated based on their out-degree
            degrees = network.get_degrees('out', nodes=pos_ids)
            ids_dep = np.repeat(pos_ids, degrees)
            x_dep = self.x[ids_dep]
            y_dep = self.y[ids_dep]

        super(AnimationNetwork, self)._draw(
            i, head, head_slice, spike_cum, spike_slice)

        return [self.line_neurons, self.line_neurons_a, self.line_spks_,
                self.line_spks_a, self.line_second_, self.line_second_a,
                self.line_second_e]

    def _init_draw(self):
        '''
        Remove ticks from spks/second axes, save background,
        then restore state to allow for moveable axes and labels.
        '''
        xlim = self.spks.get_xlim()
        xlabel = self.spks.get_xlabel()
        time_axes = (self.spks, self.second)

        # remove
        for axis in time_axes:
            axis.set_xticks([])
            axis.set_xticklabels([])
            axis.set_xlabel("")

        # background
        self.fig.canvas.draw()
        self.bg = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        # restore
        for axis in time_axes:
            axis.set_xticks(self.xticks)
            axis.set_xticklabels(self.xlabels)
            axis.set_xlim(*xlim)
            axis.set_xlabel(xlabel)

        # initialize empty lines
        lines = [self.line_spks_, self.line_spks_a, self.line_neurons_a,
                 self.line_second_, self.line_second_a, self.line_second_e,
                 self.line_neurons]
        for l in lines:
            l.set_data([], [])

        # initialize the neurons and connections between neurons
        network = self.network()
        draw_network(network, ncolor='k', axis=self.env, show=False,
                     simple_nodes=True, decimate_connections=-1, tight=False,
                     **self.kwargs)

        if network.is_spatial():
            # add a 2% margin around the bounds of the spatial environment
            xmin, ymin, xmax, ymax = network.shape.bounds
            dx = 0.02*(xmax-xmin)
            dy = 0.02*(ymax-ymin)
            self.env.set_xlim(xmin-dx, xmax+dx)
            self.env.set_ylim(ymin-dy, ymax+dy)

        self.line_neurons = self.env.lines[0]
# ----- #
# Tools #
# ----- #
def _max_axis(value, min_val=0.):
if np.isclose(value, 0.):
return -0.02*min_val
elif np.sign(value) > 0.:
return 1.02*value
else:
return 0.98*value
def _min_axis(value, max_val=0.):
if np.isclose(value, 0.):
return -0.02*max_val
elif np.sign(value) < 0.:
return 1.02*value
else:
return 0.98*value
def _convert_axis(axis_name):
lowercase = axis_name.lower()
if lowercase == "times":
return "Time (ms)"
new_name = "$"
i = axis_name.find("_")
if i != -1:
start = lowercase[:i]
if start in ("tau", "alpha", "beta", "gamma", "delta"):
new_name += "\\" + axis_name[:i] + "_{" + axis_name[i+1:] + "}$"
elif start in ("v", "e"):
new_name += axis_name[:i] + "_{" + axis_name[i+1:] + "}$ (mV)"
elif start == "i":
new_name += axis_name[:i] + "_{" + axis_name[i+1:] + "}$ (pA)"
else:
new_name += axis_name[:i] + "_{" + axis_name[i+1:] + "}$"
else:
if lowercase in ("tau", "alpha", "beta", "gamma", "delta"):
new_name += "\\" + lowercase + "$"
elif lowercase == "w":
new_name = "$w$ (pA)"
else:
new_name += lowercase + "$"
return new_name
def _save_movie(animation, filename, fps, video_encoder, codec, bitrate,
                metadata, dpi, start, stop):
    '''
    Write `animation` to `filename`.

    For ".mp4" and ".avi" files, the frames `start` to `stop` (excluded) are
    drawn one by one and piped as raw ARGB images to an ``ffmpeg`` process
    (`video_encoder`, `codec`, `bitrate`, `metadata` and `dpi` are not used
    in that case). For any other extension, the file is written through
    ``animation.save`` with the requested matplotlib writer (`start` and
    `stop` are not used in that case).

    Parameters
    ----------
    animation : animation object
        Must provide `fig`, `increment`, `_draw` and `_init_draw`.
    filename : str
        Path to the output file.
    fps : int
        Frames per second.
    video_encoder : str
        Name of the matplotlib writer ('html5' uses 'ffmpeg' with the
        'libx264' codec).
    codec : str
        Codec passed to the matplotlib writer.
    bitrate : int
        Bitrate passed to the matplotlib writer.
    metadata : dict or None
        Metadata passed to the matplotlib writer.
    dpi : float
        Resolution passed to ``animation.save``.
    start : int
        Index of the first saved frame.
    stop : int
        Index following that of the last saved frame.
    '''
    if filename.endswith(('.mp4', '.avi')):
        ffcodec = 'h264' if filename.endswith('.mp4') else 'xvid'
        fig = animation.fig
        canvas_width, canvas_height = fig.get_size_inches()*fig.dpi

        # Open an ffmpeg process
        cmdstring = (
            'ffmpeg',
            '-y', '-r', str(fps),  # overwrite output, frame rate
            '-s', '%dx%d' % (canvas_width, canvas_height),  # size of image
            '-pix_fmt', 'argb',  # format
            '-f', 'rawvideo', '-i', '-',  # expect raw video from the pipe
            '-vcodec', ffcodec, filename)  # output encoding
        p = subprocess.Popen(cmdstring, stdin=subprocess.PIPE)

        try:
            # Draw frames and write to the pipe
            for i in range(start, stop):
                frame = int(i*animation.increment)
                # draw the frame
                animation._draw(frame)
                fig.canvas.draw()
                # extract the image as an ARGB string and send it
                p.stdin.write(fig.canvas.tostring_argb())
        finally:
            # Finish up: close the pipe and wait for ffmpeg even if drawing
            # failed, so that no process is left dangling
            p.communicate()

        if p.returncode:
            warnings.warn(
                "ffmpeg exited with return code {}.".format(p.returncode))

        animation._init_draw()
    else:
        if metadata is None:
            metadata = {"artist": "NNGT"}

        html5 = (video_encoder == 'html5')
        encoder = 'ffmpeg' if html5 else video_encoder
        Writer = anim.writers[encoder]

        if html5:
            codec = 'libx264'

        writer = Writer(codec=codec, fps=fps, bitrate=bitrate,
                        metadata=metadata)
        animation.save(filename, writer=writer, dpi=dpi)
def _vector_field(q, dotx_func, doty_func, x, y, Is):
'''
Add the vector field of the x and y derivatives in phase space.
Parameters
----------
q : :class:`matplotlib.quiver.Quiver`
Phase space quiver object.
dotx_func : function
User provided function giving :math:`\dot{x} = f(x, y, Is(t))`.
doty_func : function
User provided function giving :math:`\dot{y} = g(x, y, Is(t))`.
x : :class:`numpy.ndarray`.
y : :class:`numpy.ndarray`.
Is : float
Current (time dependent data).
'''
q.set_UVC(dotx_func(x, y, Is), doty_func(x, y, Is))