#!/usr/bin/env python
#-*- coding:utf-8 -*-
#
# This file is part of the NNGT project to generate and analyze
# neuronal networks and their activity.
# Copyright (C) 2015-2017 Tanguy Fardet
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
""" Animation tools """
import warnings
import weakref
import subprocess
import numpy as np
import matplotlib as mpl
from matplotlib.lines import Line2D
import matplotlib.animation as anim
from nngt.lib.sorting import _sort_neurons
from nngt.analysis import total_firing_rate
from .plt_networks import draw_network
# ----------------- #
# Animation classes #
# ----------------- #
class _SpikeAnimator(anim.TimedAnimation):
    '''
    Generic class to plot raster plot and firing-rate in time for a given
    network.

    .. warning::
        This class is not supposed to be instantiated directly, but only
        through Animation2d or AnimationNetwork.
    '''

    # candidate spacings between x-ticks (same unit as the spike times);
    # `_make_ticks` picks the one closest to its target spacing
    steps = [
        1, 5, 10, 20, 25, 50, 100, 200, 250, 500,
        1000, 2000, 2500, 5000, 10000, 25000, 50000, 75000, 100000, 250000
    ]

    def __init__(self, source, sort_neurons=None,
                 network=None, grid=(2, 4), pos_raster=(0, 2),
                 span_raster=(1, 2), pos_rate=(1, 2),
                 span_rate=(1, 2), make_rate=True, **kwargs):
        '''
        Generate a SubplotAnimation instance to plot a network activity.

        Parameters
        ----------
        source : NEST gid tuple or str
            NEST gid of the `spike_detector`(s) which recorded the network or
            path to a file containing the recorded spikes.
        sort_neurons : str or list, optional (default: None)
            Sorting applied to the neurons in the raster plot (only used if
            `network` is provided).
        network : :class:`~nngt.Network`, optional (default: None)
            Network which activity is plotted.
        grid : tuple of ints, optional (default: (2, 4))
            Shape of the subplot grid.
        pos_raster, span_raster : tuples of ints, optional
            Position and (rowspan, colspan) of the raster plot in the grid.
        pos_rate, span_rate : tuples of ints, optional
            Position and (rowspan, colspan) of the rate plot in the grid.
        make_rate : bool, optional (default: True)
            Whether the total firing rate is computed and plotted.
        **kwargs : dict, optional
            "figsize" and "dpi" are used to create the figure.

        Raises
        ------
        RuntimeError
            If no spike was recorded after ``self.times[0]``.

        Note
        ----
        Calling class is supposed to have defined `self.times`, `self.start`,
        `self.duration`, `self.trace`, and `self.timewindow`.
        '''
        import matplotlib.pyplot as plt
        import nest
        from nngt.simulation.nest_activity import _get_data

        # organization
        self.grid = grid
        self.has_rate = make_rate

        # get data: column 0 holds the senders, column 1 the spike times
        data_s = _get_data(source)
        spikes = np.where(data_s[:, 1] >= self.times[0])[0]

        # `spikes` contains indices: test its size, not its truth value
        # (``np.any([0])`` is False although one spike was found)
        if spikes.size == 0:
            raise RuntimeError("No spikes between {} and {}.".format(
                self.start, self.times[-1]))

        idx_start = spikes[0]
        self.spikes = data_s[:, 1][idx_start:]
        self.senders = data_s[:, 0][idx_start:].astype(int)
        self._ymax = np.max(self.senders)
        self._ymin = np.min(self.senders)

        if network is None:
            # senders are assumed to span [ymin, ymax], both bounds included
            self.num_neurons = int(self._ymax - self._ymin) + 1
        else:
            self.num_neurons = network.node_nb()

        # sorting
        if sort_neurons is not None:
            if network is not None:
                sorted_neurons = _sort_neurons(
                    sort_neurons, self.senders, network, data=data_s)
                self.senders = sorted_neurons[self.senders]
            else:
                warnings.warn("Could not sort neurons because no "
                              "`network` was provided.")

        self.dt = self.times[1] - self.times[0]
        # NOTE(review): this overwrites the (absolute) `simtime` set by the
        # subclasses with a duration; the axis-limit logic relies on it.
        self.simtime = self.times[-1] - self.times[0]

        # generate the spike-rate
        if make_rate:
            self.firing_rate, _ = total_firing_rate(
                network, data=data_s, resolution=self.times)

        # figure/canvas: pause/resume and step by step interactions
        self.fig = plt.figure(
            figsize=kwargs.get("figsize", (8, 6)), dpi=kwargs.get("dpi", 75))
        self.pause = False
        self.pause_after = False
        self.event = None
        self.increment = 1
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('key_press_event', self.on_keyboard_press)
        self.fig.canvas.mpl_connect(
            'key_release_event', self.on_keyboard_release)

        # Axes for spikes and spike-rate/other representations
        self.spks = plt.subplot2grid(
            grid, pos_raster, rowspan=span_raster[0], colspan=span_raster[1])
        self.second = plt.subplot2grid(
            grid, pos_rate, rowspan=span_rate[0], colspan=span_rate[1],
            sharex=self.spks)

        # lines ("_": cumulated data, "a": recent trace, "e": current point)
        self.line_spks_ = Line2D(
            [], [], ls='None', marker='o', color='black', ms=2, mew=0)
        self.line_spks_a = Line2D(
            [], [], ls='None', marker='o', color='red', ms=2, mew=0)
        self.line_second_ = Line2D([], [], color='black')
        self.line_second_a = Line2D([], [], color='red', linewidth=2)
        self.line_second_e = Line2D(
            [], [], color='red', marker='o', markeredgecolor='r')

        # Spikes raster plot
        kw_args = {}
        if self.timewindow != self.duration:
            kw_args['xlim'] = (self.start,
                               min(self.simtime, self.timewindow + self.start))
        ylim = (self._ymin, self._ymax)
        self.lines_raster = [self.line_spks_, self.line_spks_a]
        self.set_axis(self.spks, xlabel='Time (ms)', ylabel='Neuron',
                      lines=self.lines_raster, ylim=ylim, set_xticks=True,
                      **kw_args)
        self.lines_second = [
            self.line_second_, self.line_second_a, self.line_second_e]

        # Rate plot
        if make_rate:
            self.set_axis(
                self.second, xlabel='Time (ms)', ylabel='Rate (Hz)',
                lines=self.lines_second, ydata=self.firing_rate, **kw_args)

    #-------------------------------------------------------------------------
    # Axis definition

    def set_axis(self, axis, xlabel, ylabel, lines, xdata=None, ydata=None,
                 **kwargs):
        '''
        Setup an axis: labels, lines and limits.

        Parameters
        ----------
        axis : :class:`matplotlib.axes.Axes` object
        xlabel : str
        ylabel : str
        lines : list of :class:`matplotlib.lines.Line2D` objects
            Lines added to `axis`.
        xdata : 1D array-like, optional (default: None)
            Currently unused (x limits come from "xlim" or the time ticks).
        ydata : 1D array-like, optional (default: None)
            Used to compute the y limits when "ylim" is not given.
        **kwargs : dict, optional (default: {})
            Optional arguments ("xlim" or "ylim", 2-tuples; "set_xticks",
            bool).
        '''
        axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)

        if kwargs.get('set_xticks', False):
            self._make_ticks(self.timewindow)

        for line2d in lines:
            axis.add_line(line2d)

        if 'xlim' in kwargs:
            axis.set_xlim(*kwargs['xlim'])
        else:
            # requires `self.xticks`, i.e. a previous `_make_ticks` call
            xmin, xmax = self.xticks[0], self.xticks[-1]
            axis.set_xlim(_min_axis(xmin, xmax), _max_axis(xmax, xmin))

        if 'ylim' in kwargs:
            axis.set_ylim(*kwargs['ylim'])
        else:
            ymin, ymax = np.min(ydata), np.max(ydata)
            axis.set_ylim(_min_axis(ymin, ymax), _max_axis(ymax, ymin))

    def _draw(self, i, head, head_slice, spike_cum, spike_slice):
        '''
        Update the raster/rate lines for frame `i` and scroll the time axes
        when the default time window is displayed.

        Parameters
        ----------
        i : int
            Index of the current time in `self.times`.
        head : int
            Index of the point marking the head of the rate trace.
        head_slice : boolean array
            Mask over `self.times` selecting the recent (red) rate trace.
        spike_cum : boolean array
            Mask over `self.spikes` selecting all past spikes.
        spike_slice : boolean array
            Mask over `self.spikes` selecting the recent (red) spikes.
        '''
        self.line_spks_.set_data(
            self.spikes[spike_cum], self.senders[spike_cum])

        if np.any(spike_slice):
            self.line_spks_a.set_data(
                self.spikes[spike_slice], self.senders[spike_slice])
        else:
            self.line_spks_a.set_data([], [])

        if self.has_rate:
            self.line_second_.set_data(self.times[:i], self.firing_rate[:i])
            self.line_second_a.set_data(
                self.times[head_slice], self.firing_rate[head_slice])
            # Line2D.set_data requires sequences, not scalars
            self.line_second_e.set_data(
                [self.times[head]], [self.firing_rate[head]])

        # set axis limits: 1. check whether the user changed the window
        current_window = np.diff(self.spks.get_xlim())[0]
        default_window = (np.isclose(current_window, self.timewindow)
                          or np.isclose(current_window,
                                        self.simtime - self.start))

        # 2. scroll if necessary
        if default_window:
            xlims = self.spks.get_xlim()
            if self.times[i] >= xlims[1]:
                self.spks.set_xlim(
                    self.times[i] - self.timewindow, self.times[i])
                self.second.set_xlim(
                    self.times[i] - self.timewindow, self.times[i])
            elif self.times[i] <= xlims[0]:
                self.spks.set_xlim(self.start, self.timewindow + self.start)

    def _make_ticks(self, timewindow):
        '''
        Compute `self.xticks` and `self.xlabels`, aiming at about 5 ticks
        per `timewindow`, with a spacing taken from `self.steps`.
        '''
        target_num_ticks = np.ceil(self.duration / timewindow * 5)
        target_step = self.duration / target_num_ticks
        # explicit array conversion: a list minus a float is not defined
        idx_step = np.abs(np.asarray(self.steps) - target_step).argmin()
        step = self.steps[int(idx_step)]
        num_steps = int(self.duration / step) + 2
        self.xticks = [self.start + i*step for i in range(num_steps)]
        self.xlabels = [str(i) for i in self.xticks]

    #-------------------------------------------------------------------------
    # User interaction

    def _toggle_pause(self):
        ''' Pause the animation if it runs, resume it otherwise. '''
        if self.pause:
            self.pause = False
            self.event_source.start()
        else:
            self.pause = True
            self.event_source.stop()

    def on_click(self, event):
        ''' Toggle pause on middle-click. '''
        # matplotlib reports buttons as ints (MouseButton.MIDDLE == 2);
        # comparing with the string '2' never matched
        if event.button == 2:
            self._toggle_pause()

    def on_keyboard_press(self, kb_event):
        '''
        Keyboard controls: space toggles pause; 'N'/'P' step forward or
        backward, 'F'/'B' double/halve the step size.
        '''
        if kb_event.key == ' ':
            self._toggle_pause()
        elif kb_event.key in ('B', 'F', 'N', 'P'):
            if self.pause:
                self.pause = False
                self.pause_after = True     # stop at next iteration
                self.event_source.start()   # restart temporarily
            if kb_event.key == 'F':
                self.increment *= 2
            elif kb_event.key == 'B':
                self.increment = max(1, int(self.increment / 2))
            self.event = kb_event

    def on_keyboard_release(self, kb_event):
        ''' Pause again after a step triggered while paused. '''
        if kb_event.key in (' ', 'B', 'F', 'N', 'P'):
            if self.pause_after:
                self.pause = True
                self.pause_after = False
                self.event_source.stop()    # pause again
            self.event = None

    def save_movie(self, filename, fps=30, video_encoder='html5', codec=None,
                   bitrate=-1, start=None, stop=None, interval=None,
                   num_frames=None, metadata=None):
        '''
        Save the animation to a movie file.

        Parameters
        ----------
        filename : :obj:`str`
            Name of the file where the movie will be saved.
        fps : int, optional (default: 30)
            Frame per second.
        video_encoder : :obj:`str`, optional (default 'html5')
            Movie encoding format; either 'ffmpeg', 'html5', or 'imagemagick'.
        codec : :obj:`str`, optional (default: None)
            Codec to use for writing movie; if None, default `animation.codec`
            from `matplotlib` will be used.
        bitrate : int, optional (default: -1)
            Controls size/quality tradeoff for movie. Default (-1) lets utility
            auto-determine.
        start : float, optional (default: initial time)
            Start time, corresponding to the first spike time that will appear
            on the video.
        stop : float, optional (default: final time)
            Stop time, corresponding to the last spike time that will appear
            on the video.
        interval : int, optional (default: None)
            Timestep increment for each new frame. Default saves all
            timesteps (often heavy). E.g. setting `interval` to 10 will make
            the file 10 times lighter.
        num_frames : int, optional (default: None)
            Total number of frames that should be saved.
        metadata : :obj:`dict`, optional (default: None)
            Metadata for the video (e.g. 'title', 'artist', 'comment',
            'copyright')

        Raises
        ------
        InvalidArgument
            If both `interval` and `num_frames` are provided.

        Notes
        -----
        * ``ffmpeg`` is required for 'ffmpeg' and 'html5' encoders.
          To get available formats, type ``ffmpeg -formats`` in a terminal;
          type ``ffmpeg -codecs | grep EV`` for available codecs.
        * Imagemagick is required for 'imagemagick' encoder.
        '''
        from nngt.lib import InvalidArgument

        if interval is not None and num_frames is not None:
            raise InvalidArgument("Incompatible arguments `interval` and "
                                  "`num_frames` provided. Choose one.")
        elif interval is None and num_frames is None:
            self.increment = 1
            self.save_count = self.num_frames
        elif interval is None:
            self.increment = max(1, int(self.num_frames / num_frames))
            self.save_count = num_frames
        else:
            self.increment = interval
            self.save_count = int(self.num_frames / interval)

        # movie frame k displays self.times[k * increment]: convert times
        # relative to the first recorded time and never request frames
        # beyond the available data
        t0 = self.times[0]
        frame_duration = self.dt * self.increment
        num_available = (len(self.times) - 1) // self.increment + 1

        start_frame = 0
        stop_frame = min(self.save_count, num_available)

        if start is not None:
            start_frame = max(0, int((start - t0) / frame_duration))
            self.spks.set_xlim(left=start)
            self.second.set_xlim(left=start)

        if stop is not None:
            stop_frame = min(int((stop - t0) / frame_duration) + 1,
                             num_available)
            self.spks.set_xlim(right=stop)
            self.second.set_xlim(right=stop)

        _save_movie(
            self, filename, fps, video_encoder, codec, bitrate, metadata,
            self.fig.dpi, start_frame, stop_frame)
class Animation2d(_SpikeAnimator, anim.FuncAnimation):
    '''
    Class to plot the raster plot, firing-rate, and average trajectory in
    a 2D phase-space for a network activity.
    '''

    def __init__(self, source, multimeter, start=0., timewindow=None,
                 trace=5., x='time', y='V_m', sort_neurons=None,
                 network=None, interval=50, vector_field=False, **kwargs):
        r'''
        Generate a SubplotAnimation instance to plot a network activity.

        Parameters
        ----------
        source : tuple
            NEST gid of the ``spike_detector``(s) which recorded the network.
        multimeter : tuple
            NEST gid of the ``multimeter``(s) which recorded the network.
        start : double, optional (default: 0.)
            Time from which the recorded data is displayed.
        timewindow : double, optional (default: None)
            Time window which will be shown for the spikes and self.second.
        trace : double, optional (default: 5.)
            Interval of time (ms) over which the data is overlayed in red.
        x : str, optional (default: "time")
            Name of the `x`-axis variable (must be either "time" or the name
            of a NEST recordable in the `multimeter`).
        y : str, optional (default: "V_m")
            Name of the `y`-axis variable (must be either "time" or the name
            of a NEST recordable in the `multimeter`).
        sort_neurons : str or list, optional (default: None)
            Sort neurons using a topological property ("in-degree",
            "out-degree", "total-degree" or "betweenness"), an activity-related
            property ("firing_rate", 'B2') or a user-defined list of sorted
            neuron ids. Sorting is performed by increasing value of the
            `sort_neurons` property from bottom to top inside each group.
        network : :class:`~nngt.Network`, optional (default: None)
            Network which activity is plotted (required for sorting).
        interval : int, optional (default: 50)
            Delay between frames in milliseconds, passed to
            :class:`matplotlib.animation.FuncAnimation`.
        vector_field : bool, optional (default: False)
            Whether the :math:`\dot{x}` and :math:`\dot{y}` arrows should be
            added to phase space. Requires additional 'dotx' and 'doty'
            arguments which are user defined functions to compute the
            derivatives of `x` and `y` in time. These functions take 3
            parameters, which are `x`, `y`, and `time_dependent`, where the
            last parameter is a list of doubles associated to recordables
            from the neuron model (see example for details). These recordables
            must be declared in a `time_dependent` parameter.
        **kwargs : dict, optional (default: {})
            Optional arguments such as 'make_rate', 'num_xarrows',
            'num_yarrows', 'dotx', 'doty', 'time_dependent', 'recordables',
            'arrow_scale'.
        '''
        import matplotlib.pyplot as plt
        import nest

        x = "times" if x == "time" else x
        y = "times" if y == "time" else y

        # get data
        data_mm = nest.GetStatus(multimeter)[0]["events"]
        all_times = data_mm["times"]
        idx_start = np.where(all_times >= start)[0][0]
        self.idx_start = idx_start
        self.times = all_times[idx_start:]
        # count frames on the *sliced* times, otherwise `save_movie`
        # requests frames beyond the available data
        self.num_frames = len(self.times)
        self.simtime = self.times[-1]
        self.start = start
        self.duration = self.simtime - start
        self.trace = trace
        self.vector_field = vector_field

        if timewindow is None:
            self.timewindow = self.duration
        else:
            self.timewindow = min(timewindow, self.duration)

        # init _SpikeAnimator parent class (create figure and right axes)
        if 'make_rate' not in kwargs:
            kwargs['make_rate'] = True

        super(Animation2d, self).__init__(
            source, sort_neurons=sort_neurons, network=network,
            **kwargs)

        # Data and axis for phase-space: recorded values are divided by the
        # number of neurons, but times must stay untouched
        self.x = data_mm[x][idx_start:]
        if x != "times":
            self.x = self.x / self.num_neurons
        self.y = data_mm[y][idx_start:]
        if y != "times":
            self.y = self.y / self.num_neurons

        self.ps = plt.subplot2grid((2, 4), (0, 0), rowspan=2, colspan=2)
        self.ps.grid(False)

        # lines ("_": full trajectory, "a": recent trace, "e": current point)
        self.line_ps_ = Line2D([], [], color='black')
        self.line_ps_a = Line2D([], [], color='red', linewidth=2)
        self.line_ps_e = Line2D(
            [], [], color='red', marker='o', markeredgecolor='r')
        lines = [self.line_ps_, self.line_ps_a, self.line_ps_e]
        xlim = (_min_axis(self.x.min()), _max_axis(self.x.max()))
        self.set_axis(
            self.ps, xlabel=_convert_axis(x), ylabel=_convert_axis(y),
            lines=lines, xdata=self.x, ydata=self.y, xlim=xlim)

        # For quiver plot (vector field)
        nx = kwargs.get('num_xarrows', 20)
        ny = kwargs.get('num_yarrows', 20)
        scale = kwargs.get('arrow_scale', 30.)

        if self.vector_field:
            time_dependent_rec = kwargs.get('time_dependent', [])
            self.time_dependent = [
                data_mm[key][idx_start:] / self.num_neurons
                for key in time_dependent_rec
            ]
            self.dotx, self.doty = kwargs['dotx'], kwargs['doty']
            xx = np.repeat(np.linspace(xlim[0], xlim[1], nx), ny)
            yy = np.tile(np.linspace(self.y.min(), self.y.max(), ny), nx)
            self.q = self.ps.quiver(xx, yy, [], [], scale=scale, color='grey')

        plt.tight_layout()

        anim.FuncAnimation.__init__(self, self.fig, self._draw, self._gen_data,
                                    interval=interval, blit=True)

    #-------------------------------------------------------------------------
    # Animation instructions

    def _gen_data(self):
        '''
        Generate the frame indices, taking pause and 'N'/'P' steps into
        account (each step event is consumed once).
        '''
        i = -1
        imax = len(self.x) - 1

        while i < imax - self.increment:
            if not self.pause:
                if self.event is not None:
                    if self.event.key == 'N':
                        i += self.increment
                    elif self.event.key == 'P':
                        # never step back before the first frame (a negative
                        # index would wrap around to the end of the data)
                        i = max(0, i - self.increment)
                    self.event = None
                else:
                    i += self.increment
            yield i

    def _draw(self, framedata):
        '''
        Draw frame `framedata` (an index in `self.times`) and return the
        artists to blit.
        '''
        i = int(framedata)
        t_now = self.times[i]
        head = i - 1
        head_slice = (self.times > t_now - self.trace) & (self.times < t_now)
        spike_slice = ((self.spikes > t_now - self.trace)
                       & (self.spikes <= t_now))
        spike_cum = self.spikes < t_now

        lines = []

        if self.vector_field:
            time_dep = [arr[i] for arr in self.time_dependent]
            u = self.dotx(self.q.X, self.q.Y, time_dep)
            v = self.doty(self.q.X, self.q.Y, time_dep)
            self.q.set_UVC(u, v)
            lines.append(self.q)

        self.line_ps_.set_data(self.x[:i], self.y[:i])
        self.line_ps_a.set_data(self.x[head_slice], self.y[head_slice])
        # Line2D.set_data requires sequences, not scalars
        self.line_ps_e.set_data([self.x[i]], [self.y[i]])

        lines.extend([self.line_ps_, self.line_ps_a, self.line_ps_e,
                      self.line_spks_, self.line_spks_a, self.line_second_,
                      self.line_second_a, self.line_second_e])

        super(Animation2d, self)._draw(
            i, head, head_slice, spike_cum, spike_slice)

        return lines

    def _init_draw(self):
        '''
        Remove ticks from spks/second axes, save background,
        then restore state to allow for moveable axes and labels.
        '''
        xlim = self.spks.get_xlim()
        xlabel = self.spks.get_xlabel()

        # remove
        self.spks.set_xticks([])
        self.spks.set_xticklabels([])
        self.spks.set_xlabel("")
        self.second.set_xticks([])
        self.second.set_xticklabels([])
        self.second.set_xlabel("")

        # background
        self.fig.canvas.draw()
        self.bg = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        # restore
        self.spks.set_xticks(self.xticks)
        self.spks.set_xticklabels(self.xlabels)
        self.spks.set_xlim(*xlim)
        self.spks.set_xlabel(xlabel)
        self.second.set_xticks(self.xticks)
        self.second.set_xticklabels(self.xlabels)
        self.second.set_xlim(*xlim)
        self.second.set_xlabel(xlabel)

        if self.vector_field:
            self.q.set_UVC([], [])

        # initialize empty lines
        lines = [self.line_ps_, self.line_ps_a, self.line_ps_e,
                 self.line_spks_, self.line_spks_a,
                 self.line_second_, self.line_second_a, self.line_second_e]

        for line in lines:
            line.set_data([], [])
class AnimationNetwork(_SpikeAnimator, anim.FuncAnimation):
    '''
    Class to plot the raster plot, firing-rate, and space-embedded spiking
    activity (neurons on the graph representation flash when spiking) in time.
    '''

    def __init__(self, source, network, resolution=1., start=0.,
                 timewindow=None, trace=5., show_spikes=False,
                 sort_neurons=None, decimate_connections=False,
                 interval=50, repeat=True, resting_size=None, active_size=None,
                 **kwargs):
        '''
        Generate a SubplotAnimation instance to plot a network activity.

        Parameters
        ----------
        source : tuple
            NEST gid of the ``spike_detector``(s) which recorded the network.
        network : :class:`~nngt.SpatialNetwork`
            Network embedded in space to plot the actvity of the neurons in
            space.
        resolution : double, optional (default: 1.)
            Time resolution of the animation.
        start : double, optional (default: 0.)
            Time from which the activity is displayed.
        timewindow : double, optional (default: None)
            Time window which will be shown for the spikes and self.second.
        trace : double, optional (default: 5.)
            Interval of time (ms) over which the data is overlayed in red.
        show_spikes : bool, optional (default: False)
            Whether a spike trajectory should be displayed on the network
            (not implemented yet: the trajectory lines are never updated).
        sort_neurons : str or list, optional (default: None)
            Sort neurons using a topological property ("in-degree",
            "out-degree", "total-degree" or "betweenness"), an activity-related
            property ("firing_rate", 'B2') or a user-defined list of sorted
            neuron ids. Sorting is performed by increasing value of the
            `sort_neurons` property from bottom to top inside each group.
        decimate_connections : bool, optional (default: False)
            Currently unused.
        interval : int, optional (default: 50)
            Delay between frames in milliseconds, passed to
            :class:`matplotlib.animation.FuncAnimation`.
        repeat : bool, optional (default: True)
            Whether the animation restarts when it reaches the end.
        resting_size : double, optional (default: None)
            Marker size of the neurons; computed from the axis area if None.
        active_size : double, optional (default: None)
            Marker size of spiking neurons; defaults to ``resting_size + 2``.
        **kwargs : dict, optional (default: {})
            Optional arguments such as 'make_rate', or all arguments for the
            :func:`nngt.plot.draw_network`.
        '''
        import matplotlib.pyplot as plt
        import nest
        from nngt.simulation.nest_activity import _get_data

        # weak reference: the animation does not keep the network alive
        self.network = weakref.ref(network)
        # last spike time (presumably the recorded data is sorted by time)
        self.simtime = _get_data(source)[-1, 1]
        self.times = np.arange(start, self.simtime + resolution, resolution)
        self.num_frames = len(self.times)
        self.start = start
        self.duration = self.simtime - start
        self.trace = trace
        self.show_spikes = show_spikes

        if timewindow is None:
            self.timewindow = self.duration
        else:
            self.timewindow = min(timewindow, self.duration)

        # init _SpikeAnimator parent class (create figure and right axes)
        self.kwargs = kwargs
        cs = kwargs.get('chunksize', 10000)
        mpl.rcParams['agg.path.chunksize'] = cs

        if 'make_rate' not in kwargs:
            kwargs['make_rate'] = True

        super(AnimationNetwork, self).__init__(
            source, sort_neurons=sort_neurons, network=network,
            **kwargs)

        self.env = plt.subplot2grid((2, 4), (0, 0), rowspan=2, colspan=2)

        # Data and axis for network representation
        bbox = self.env.get_window_extent().transformed(
            self.fig.dpi_scale_trans.inverted())
        area_px = bbox.width * bbox.height * self.fig.dpi**2

        # neuron size
        n_size = (resting_size if resting_size is not None
                  else max(2, 0.5*np.sqrt(area_px/self.num_neurons)))
        if active_size is None:
            active_size = n_size + 2

        pos = network.get_positions()  # positions of the neurons
        self.x = pos[:, 0]
        self.y = pos[:, 1]

        # neurons
        self.line_neurons = Line2D(
            [], [], ls='None', marker='o', color='black', ms=n_size, mew=0)
        self.line_neurons_a = Line2D(
            [], [], ls='None', marker='o', color='red', ms=active_size, mew=0)
        self.lines_env = [self.line_neurons, self.line_neurons_a]
        xlim = (_min_axis(self.x.min()), _max_axis(self.x.max()))
        self.set_axis(self.env, xlabel='Network', ylabel='',
                      lines=self.lines_env, xdata=self.x, ydata=self.y,
                      xlim=xlim)

        # spike trajectory (lines created but not added to the axis nor
        # updated yet)
        if show_spikes:
            self.line_st_a = Line2D([], [], color='red', linewidth=1)
            self.line_st_e = Line2D(
                [], [], color='red', marker='d', ms=2, markeredgecolor='r')
            self.lines_env.extend((self.line_st_a, self.line_st_e))

        # remove the axes and grid from env
        self.env.set_xticks([])
        self.env.set_yticks([])
        self.env.set_xticklabels([])
        self.env.set_yticklabels([])
        self.env.grid(None)

        plt.tight_layout()

        anim.FuncAnimation.__init__(
            self, self.fig, self._draw, self._gen_data, repeat=repeat,
            interval=interval, blit=True)

    #-------------------------------------------------------------------------
    # Animation instructions

    def _gen_data(self):
        '''
        Generate the frame indices, taking pause and 'N'/'P' keys into
        account (the step is repeated while the key is held).
        '''
        i = -1
        imax = len(self.times) - 1

        while i < imax - self.increment:
            if not self.pause:
                if self.event is not None:
                    if self.event.key == 'N':
                        i += self.increment
                    elif self.event.key == 'P':
                        # never step back before the first frame (a negative
                        # index would wrap around to the end of the data)
                        i = max(0, i - self.increment)
                else:
                    i += self.increment
            yield i

    def _draw(self, framedata):
        '''
        Draw frame `framedata` (an index in `self.times`) and return the
        artists to blit.
        '''
        i = int(framedata)

        if i == 0:  # initialize neurons
            self.line_neurons.set_data(self.x, self.y)

        t_now = self.times[i]
        head = i - 1
        head_slice = (self.times > t_now - self.trace) & (self.times < t_now)
        spike_slice = ((self.spikes > t_now - self.trace)
                       & (self.spikes <= t_now))
        spike_cum = self.spikes < t_now

        # highlight the neurons which spiked during the trace interval
        pos_ids = self.network().id_from_nest_gid(self.senders[spike_slice])
        self.line_neurons_a.set_data(self.x[pos_ids], self.y[pos_ids])

        # TODO: spike trajectories for `show_spikes` are not drawn yet. The
        # intended approach: take the spikers (`pos_ids`), repeat them by
        # out-degree, and move a marker from each source to its targets
        # between the spike time and the spike time plus the delay (to be
        # made to work for heterogeneous delays).

        super(AnimationNetwork, self)._draw(
            i, head, head_slice, spike_cum, spike_slice)

        return [self.line_neurons, self.line_neurons_a, self.line_spks_,
                self.line_spks_a, self.line_second_, self.line_second_a,
                self.line_second_e]

    def _init_draw(self):
        '''
        Remove ticks from spks/second axes, save background,
        then restore state to allow for moveable axes and labels.
        '''
        # remove
        xlim = self.spks.get_xlim()
        xlabel = self.spks.get_xlabel()
        self.spks.set_xticks([])
        self.spks.set_xticklabels([])
        self.spks.set_xlabel("")
        self.second.set_xticks([])
        self.second.set_xticklabels([])
        self.second.set_xlabel("")

        # background
        self.fig.canvas.draw()
        self.bg = self.fig.canvas.copy_from_bbox(self.fig.bbox)

        # restore
        self.spks.set_xticks(self.xticks)
        self.spks.set_xticklabels(self.xlabels)
        self.spks.set_xlim(*xlim)
        self.spks.set_xlabel(xlabel)
        self.second.set_xticks(self.xticks)
        self.second.set_xticklabels(self.xlabels)
        self.second.set_xlim(*xlim)
        self.second.set_xlabel(xlabel)

        # initialize empty lines
        lines = [self.line_spks_, self.line_spks_a, self.line_neurons_a,
                 self.line_second_, self.line_second_a, self.line_second_e,
                 self.line_neurons]

        for line in lines:
            line.set_data([], [])

        # initialize the neurons and connections between neurons
        draw_network(self.network(), ncolor='k', axis=self.env, show=False,
                     simple_nodes=True, decimate=-1, tight=False, **self.kwargs)

        if self.network().is_spatial():
            shape = self.network().shape
            xmin, ymin, xmax, ymax = shape.bounds
            dx = 0.02*(xmax-xmin)
            dy = 0.02*(ymax-ymin)
            self.env.set_xlim(xmin-dx, xmax+dx)
            self.env.set_ylim(ymin-dy, ymax+dy)

        # use the node markers drawn by `draw_network` as resting neurons
        self.line_neurons = self.env.lines[0]
# ----- #
# Tools #
# ----- #
def _max_axis(value, min_val=0.):
if np.isclose(value, 0.):
return -0.02*min_val
elif np.sign(value) > 0.:
return 1.02*value
else:
return 0.98*value
def _min_axis(value, max_val=0.):
if np.isclose(value, 0.):
return -0.02*max_val
elif np.sign(value) < 0.:
return 1.02*value
else:
return 0.98*value
def _convert_axis(axis_name):
lowercase = axis_name.lower()
if lowercase == "times":
return "Time (ms)"
new_name = "$"
i = axis_name.find("_")
if i != -1:
start = lowercase[:i]
if start in ("tau", "alpha", "beta", "gamma", "delta"):
new_name += "\\" + axis_name[:i] + "_{" + axis_name[i+1:] + "}$"
elif start in ("v", "e"):
new_name += axis_name[:i] + "_{" + axis_name[i+1:] + "}$ (mV)"
elif start == "i":
new_name += axis_name[:i] + "_{" + axis_name[i+1:] + "}$ (pA)"
else:
new_name += axis_name[:i] + "_{" + axis_name[i+1:] + "}$"
else:
if lowercase in ("tau", "alpha", "beta", "gamma", "delta"):
new_name += "\\" + lowercase + "$"
elif lowercase == "w":
new_name = "$w$ (pA)"
else:
new_name += lowercase + "$"
return new_name
def _save_movie(animation, filename, fps, video_encoder, codec, bitrate,
                metadata, dpi, start, stop):
    '''
    Write `animation` to a movie file.

    For '.mp4' and '.avi' files, frames ``start`` to ``stop - 1`` are drawn
    one by one (frame ``k`` displays time index ``k * animation.increment``)
    and piped as raw ARGB images to an ``ffmpeg`` subprocess. Other files are
    written through ``animation.save`` with the matplotlib writer selected by
    `video_encoder` ('html5' uses the ffmpeg writer).

    Parameters
    ----------
    animation : _SpikeAnimator
        Animation to save.
    filename : str
        Output file.
    fps : int
        Frames per second.
    video_encoder : str
        'ffmpeg', 'html5' or 'imagemagick' (ignored for '.mp4'/'.avi').
    codec : str or None
        Video codec; if None, 'h264' (.mp4) or 'xvid' (.avi) are used for the
        ffmpeg pipe, 'libx264' for 'html5', and matplotlib's default
        otherwise.
    bitrate : int
        Bitrate in kbps; values <= 0 let the encoder decide.
    metadata : dict or None
        Video metadata (defaults to ``{"artist": "NNGT"}``).
    dpi : float
        Resolution used by ``animation.save``.
    start, stop : int
        First (included) and last (excluded) movie frame for the ffmpeg
        pipe.
    '''
    if metadata is None:
        metadata = {"artist": "NNGT"}

    if filename.endswith(('.mp4', '.avi')):
        if codec is None:
            codec = 'h264' if filename.endswith('.mp4') else 'xvid'

        fig = animation.fig
        canvas_width, canvas_height = fig.get_size_inches()*fig.dpi

        # ffmpeg reads raw ARGB frames of the figure size from stdin
        cmdstring = [
            'ffmpeg',
            '-y',                                            # overwrite
            '-r', str(fps),                                  # frame rate
            '-s', '%dx%d' % (canvas_width, canvas_height),   # image size
            '-pix_fmt', 'argb',                              # input format
            '-f', 'rawvideo', '-i', '-',                     # raw pipe input
            '-vcodec', codec,                                # output codec
        ]

        if bitrate is not None and bitrate > 0:
            cmdstring.extend(['-b:v', '{}k'.format(bitrate)])

        for key, value in metadata.items():
            cmdstring.extend(['-metadata', '{}={}'.format(key, value)])

        cmdstring.append(filename)

        proc = subprocess.Popen(cmdstring, stdin=subprocess.PIPE)

        try:
            # draw frames and write them to the pipe
            for i in range(start, stop):
                animation._draw(int(i*animation.increment))
                fig.canvas.draw()
                proc.stdin.write(fig.canvas.tostring_argb())
        finally:
            # close the pipe and wait for ffmpeg even if drawing failed,
            # then restore the interactive state of the animation
            proc.communicate()
            animation._init_draw()
    else:
        encoder = 'ffmpeg' if video_encoder == 'html5' else video_encoder
        Writer = anim.writers[encoder]

        # 'html5' defaults to H.264, but an explicit codec takes precedence
        if video_encoder == 'html5' and codec is None:
            codec = 'libx264'

        writer = Writer(codec=codec, fps=fps, bitrate=bitrate,
                        metadata=metadata)
        animation.save(filename, writer=writer, dpi=dpi)
def _vector_field(q, dotx_func, doty_func, x, y, Is):
'''
Add the vector field of the x and y derivatives in phase space.
Parameters
----------
q : :class:`matplotlib.quiver.Quiver`
Phase space quiver object.
dotx_func : function
User provided function giving :math:`\dot{x} = f(x, y, Is(t))`.
doty_func : function
User provided function giving :math:`\dot{y} = g(x, y, Is(t))`.
x : :class:`numpy.ndarray`.
y : :class:`numpy.ndarray`.
Is : float
Current (time dependent data).
'''
q.set_UVC(dotx_func(x, y, Is), doty_func(x, y, Is))