#!/usr/bin/env python
#-*- coding:utf-8 -*-
#
# This file is part of the NNGT project to generate and analyze
# neuronal networks and their activity.
# Copyright (C) 2015-2019 Tanguy Fardet
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
""" Configuration tools for NNGT """
import os
import sys
import logging
from copy import deepcopy
import numpy as np
import nngt
from .errors import InvalidArgument
from .logger import _configure_logger, _init_logger, _log_message
from .reloading import reload_module
from .rng_tools import seed as nngt_seed
from .test_functions import mpi_checker, num_mpi_processes, mpi_barrier
# Module-level logger, named after this module; used by the helpers below
# when reporting warnings and configuration changes.
logger = logging.getLogger(__name__)
# ----------------- #
# Getter and setter #
# ----------------- #
def get_config(key=None, detailed=False):
    '''
    Get the NNGT configuration as a dictionary.

    Parameters
    ----------
    key : str, optional (default: None)
        Name of a single configuration entry. If provided, only the value
        associated to that entry is returned.
    detailed : bool, optional (default: False)
        Whether the full configuration should be returned. If False, the
        technical entries, the logging entries (``log_*``), the database
        entries (``db_*``, when the database is not used), and the MPI
        communicator (when MPI is not used) are hidden.

    Returns
    -------
    A copy of the configuration as a dict if `key` is None, otherwise the
    value stored for `key` (a KeyError is raised if `key` is unknown).

    Note
    ----
    This function has no MPI barrier on it.
    '''
    if key is not None:
        return nngt._config[key]

    # shallow copy so that callers cannot alter the live configuration
    cfg = dict(nngt._config)

    if detailed:
        return cfg

    # technical entries that are never shown in the short summary
    hidden = {
        "load_nest", "graph", "library", "palette", "use_tex",
        "mpl_backend", "color_lib",
    }

    # hide mpi conf if not used
    if not nngt._config["mpi"]:
        hidden.add("mpi_comm")

    # hide database config if not used
    hide_db = not nngt._config["use_database"]

    # filtering (instead of `del`) tolerates entries that are not present,
    # e.g. 'mpi_comm' which is only set once MPI has been switched on
    return {
        name: val for name, val in cfg.items()
        if not (name in hidden
                or name.startswith('log_')
                or (hide_db and name.startswith('db_')))
    }
@mpi_barrier
def set_config(config, value=None, silent=False):
    '''
    Set NNGT's configuration.

    Parameters
    ----------
    config : dict or str
        Either a full configuration dictionary or one key to be set together
        with its associated value.
    value : object, optional (default: None)
        Value associated to `config` if `config` is a key.
    silent : bool, optional (default: False)
        If True, the summary of the new configuration is not logged, even
        if the configuration changed.

    Raises
    ------
    KeyError
        If one of the entries is not a known configuration property.

    Examples
    --------
    >>> nngt.set_config({'multithreading': True, 'omp': 4})
    >>> nngt.set_config('multithreading', False)

    Notes
    -----
    See the config file `nngt/nngt.conf.default` or `~/.nngt/nngt.conf` for
    details about your configuration.

    This function has an MPI barrier on it, so it must always be called on all
    processes.

    See also
    --------
    :func:`~nngt.get_config`
    '''
    # remember the previous state to detect what actually changed
    old_mt = nngt._config["multithreading"]
    old_mpi = nngt._config["mpi"]
    old_omp = nngt._config["omp"]
    old_gl = nngt._config["backend"]
    old_msd = nngt._config["msd"]
    old_config = nngt._config.copy()

    # work on a private dict so that the caller's argument is not modified
    if isinstance(config, dict):
        new_config = config.copy()
    else:
        new_config = {config: value}

    # only existing keys are reassigned below, so iterating is safe
    for key in new_config:
        if key not in nngt._config:
            raise KeyError("Unknown configuration property: {}".format(key))
        if key == "log_level":
            new_config[key] = _convert(new_config[key])
        if key == "backend" and new_config[key] != old_gl:
            nngt.use_backend(new_config[key])
        if key == "log_folder":
            new_config["log_folder"] = os.path.abspath(
                os.path.expanduser(new_config["log_folder"]))
        if key == "db_folder":
            new_config["db_folder"] = os.path.abspath(
                os.path.expanduser(new_config["db_folder"]))

    # check multithreading status and number of threads
    _pre_update_parallelism(new_config, old_mt, old_omp, old_mpi)

    # update
    nngt._config.update(new_config)

    # apply multithreading parameters
    _post_update_parallelism(new_config, old_gl, old_msd, old_mt, old_mpi)

    # update matplotlib
    if nngt._config['use_tex']:
        import matplotlib
        matplotlib.rc('text', usetex=True)

    # update database
    if nngt._config["use_database"] and not hasattr(nngt, "db"):
        from .. import database
        sys.modules["nngt.database"] = database
        # NOTE(review): the nesting of this check could not be recovered
        # from the flattened source; it is assumed to belong to the
        # database branch -- confirm against the upstream file.
        if nngt._config["db_to_file"]:
            _log_message(logger, "WARNING",
                         "This functionality is not available")

    # update nest
    if nngt._config["load_nest"] or nngt._config["with_nest"]:
        try:
            import nest as _nest
            from .. import simulation
            sys.modules["nngt.simulation"] = simulation
            nngt._config["with_nest"] = _nest.version()
        except Exception:
            # best effort: any failure while loading NEST disables its
            # support, but KeyboardInterrupt/SystemExit still propagate
            nngt._config["with_nest"] = False

    # log changes
    _configure_logger(nngt._logger)

    # fall back on nngt itself for the version if no graph library is loaded
    glib = (nngt._config["library"] if nngt._config["library"] is not None
            else nngt)

    num_mpi = num_mpi_processes()
    s_mpi = False if not nngt._config["mpi"] else "True ({} process{})".format(
        num_mpi, "es" if num_mpi > 1 else "")

    # probe the optional dependencies reported in the summary
    try:
        import svg.path
        has_svg = True
    except Exception:
        has_svg = False

    try:
        import dxfgrabber
        has_dxf = True
    except Exception:
        has_dxf = False

    try:
        import shapely
        has_shapely = shapely.__version__
    except Exception:
        has_shapely = False

    conf_info = config_info.format(
        gl=nngt._config["backend"] + " " + glib.__version__[:5],
        thread=nngt._config["multithreading"],
        plot=nngt._config["with_plot"],
        nest=nngt._config["with_nest"],
        db=nngt._config["use_database"],
        omp=nngt._config["omp"],
        s="s" if nngt._config["omp"] > 1 else "",
        mpi=s_mpi,
        shapely=has_shapely,
        svg=has_svg,
        dxf=has_dxf,
    )

    if not silent and old_config != nngt._config:
        _log_conf_changed(conf_info)
# ----- #
# Tools #
# ----- #
def _convert(value):
    '''
    Convert a raw configuration value into a Python object.

    The value is first cast to ``str``, then interpreted as follows:

    * a string made only of digits becomes an ``int``,
    * "true"/"false" (case-insensitive) become the matching ``bool``,
    * a logging level name (case-insensitive) among CRITICAL, DEBUG, ERROR,
      INFO, WARNING becomes the corresponding :mod:`logging` constant,
    * anything else is returned as the string itself.
    '''
    value = str(value)

    if value.isdigit():
        return int(value)

    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False

    log_levels = {
        "CRITICAL": logging.CRITICAL,
        "DEBUG": logging.DEBUG,
        "ERROR": logging.ERROR,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
    }

    # unknown entries are kept as plain strings
    return log_levels.get(value.upper(), value)
def _load_config(path_config):
    '''
    Load the configuration file `path_config` into ``nngt._config``.

    Each useful line is expected to be of the form ``name = value``; the
    value is converted through :func:`_convert` before being stored.
    Blank lines, comment lines (first non-blank character is ``#``), and
    lines that do not contain ``=`` are ignored. The NNGT logger is
    (re)initialized once the file has been read.

    Parameters
    ----------
    path_config : str
        Path to the configuration file.

    Returns
    -------
    None, the settings are written directly into ``nngt._config``.
    '''
    with open(path_config, 'r') as fconfig:
        for raw_line in fconfig:
            line = raw_line.strip()

            # test the stripped line so that indented comments are skipped
            if not line or line.startswith("#"):
                continue

            name, sep, raw_value = line.partition("=")

            # without "=" there is no (name, value) pair to store
            if not sep:
                continue

            nngt._config[name.strip()] = _convert(raw_value.strip())

    _init_logger(nngt._logger)
@mpi_checker(logging=True)
def _log_conf_changed(conf_info):
    '''
    Send the configuration summary `conf_info` to the module logger at INFO
    level (the call goes through the `mpi_checker` decorator).
    '''
    logger.log(logging.INFO, conf_info)
def _set_gt_config(old_gl, new_config):
    '''
    Forward a new OpenMP thread number to the graph library when the
    backend is (and stays) graph-tool.

    Nothing is done unless "omp" is part of `new_config`, both the previous
    backend `old_gl` and the requested one are "graph-tool", and a graph
    library is loaded. If NEST is in use and its "local_num_threads" differs
    from the requested value, a warning is logged and the library is left
    untouched.
    '''
    backend_name = "graph-tool"

    was_gt = old_gl == backend_name
    stays_gt = new_config.get("backend", old_gl) == backend_name
    lib_loaded = nngt._config["library"] is not None

    if "omp" not in new_config or not (was_gt and stays_gt and lib_loaded):
        return

    requested = new_config["omp"]

    # without NEST there is nothing to be consistent with
    nest_threads = requested
    if nngt._config['with_nest']:
        import nest
        nest_threads = nest.GetKernelStatus("local_num_threads")

    if nest_threads != requested:
        _log_message(logger, "WARNING",
                     "Using NEST and graph_tool, OpenMP number must be "
                     "consistent throughout the code. Current NEST "
                     "config states omp = " + str(nest_threads) + ", hence "
                     "`graph_tool` configuration was not changed.")
        return

    nngt._config["library"].openmp_set_num_threads(nngt._config["omp"])
def _pre_update_parallelism(new_config, old_mt, old_omp, old_mpi):
    '''
    Make the parallelism-related entries of `new_config` consistent before
    they are written into the configuration.

    `new_config` is modified in place:

    * "multithreading" is switched on if "omp" > 1 is requested while it was
      off and not explicitly given;
    * requesting both "mpi" and "multithreading" raises `InvalidArgument`;
      otherwise enabling one disables the other;
    * if "seeds" are provided, their number and uniqueness are asserted;
      otherwise "seeds" and "msd" are reset when the parallel setup changes.

    `old_mt`, `old_omp`, and `old_mpi` are the previous values of the
    "multithreading", "omp", and "mpi" entries.
    '''
    mt_key = "multithreading"

    if "omp" in new_config and new_config["omp"] > 1:
        mt_given = mt_key in new_config
        if mt_given and not new_config[mt_key]:
            _log_message(logger, "WARNING",
                         "Updating to 'multithreading' == False with "
                         "'omp' greater than one.")
        elif not mt_given and not old_mt:
            new_config[mt_key] = True
            _log_message(logger, "WARNING",
                         "'multithreading' was set to False but new "
                         "'omp' is greater than one. Updating "
                         "'multithreading' to True.")

    wants_mpi = new_config.get('mpi', False)
    wants_mt = new_config.get(mt_key, False)

    if wants_mpi and wants_mt:
        raise InvalidArgument('Cannot set both "mpi" and "multithreading" to '
                              'True simultaneously, choose one or the other.')

    if wants_mt:
        new_config['mpi'] = False
    elif wants_mpi and old_mt:
        new_config[mt_key] = False
        _log_message(logger, "WARNING",
                     '"mpi" set to True but previous configuration was '
                     'using OpenMP; setting "multithreading" to False '
                     'to switch to mpi algorithms.')

    # effective values once the update will have been applied
    with_mt = new_config.get(mt_key, old_mt)
    with_mpi = new_config.get('mpi', old_mpi)

    seeds = new_config.get('seeds', None)

    if seeds is None:
        # reset the seeds if the parallel setup changed, i.e. if
        # - the number of threads changed
        # NOTE(review): a missing "omp" entry defaults to 1 here rather
        # than to `old_omp` -- confirm that this is intended.
        omp_changed = new_config.get("omp", 1) != nngt._config["omp"]
        # - we switched from OpenMP to MPI, or from MPI to OpenMP
        reset_seeds = (omp_changed + (with_mpi and old_mt)
                       + (with_mt and old_mpi))

        if reset_seeds:
            new_config['seeds'] = None
            new_config['msd'] = None
            nngt._seeded = False
        return

    # seeds were provided: check that they are correct
    msg_count = 'Expected {} seeds.'
    msg_unique = 'All seeds must be different.'

    if with_mpi:
        from mpi4py import MPI
        num_proc = MPI.COMM_WORLD.Get_size()
        assert num_proc == len(seeds), msg_count.format(num_proc)
        assert len(set(seeds)) == len(seeds), msg_unique
    elif with_mt:
        num_threads = new_config.get("omp", old_omp)
        assert num_threads == len(seeds), msg_count.format(num_threads)
        assert len(set(seeds)) == len(seeds), msg_unique
def _post_update_parallelism(new_config, old_gl, old_msd, old_mt, old_mpi):
    '''
    Apply the parallelism-related settings once ``nngt._config`` has been
    updated with `new_config`.

    The connectivity module is reloaded when "multithreading" or "mpi"
    changed, "omp" and "seeds" are reset when multithreading is off, the MPI
    communicator is stored and the master seed is checked (on rank 0) when
    MPI is on, the graph-tool thread number is updated, and the RNGs are
    seeded again if the master seed changed or if NNGT is not seeded.

    `old_gl`, `old_msd`, `old_mt`, and `old_mpi` are the previous values of
    the "backend", "msd", "multithreading", and "mpi" entries.

    Raises
    ------
    InvalidArgument
        If MPI is on and the 'msd' entries gathered on rank 0 are not the
        same for all processes.
    '''
    # reload for omp
    new_multithreading = new_config.get("multithreading", old_mt)

    if new_multithreading != old_mt:
        reload_module(sys.modules["nngt"].generation.graph_connectivity)

    # if multithreading loading failed, set omp back to 1
    if not nngt._config['multithreading']:
        nngt._config['omp'] = 1
        nngt._config['seeds'] = None

    # if MPI is on, set mpi_comm and check random numbers
    if new_config.get('mpi', old_mpi):
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        nngt._config['mpi_comm'] = comm

        # check that master seed is the same everywhere
        msd = comm.gather(nngt._config['msd'], root=0)

        if rank == 0:
            if None not in msd:
                msd = np.array(msd, dtype=int)
                # `np.alltrue` was removed in NumPy 2.0, `np.all` is the
                # equivalent supported spelling
                if not np.all(msd == msd[0]):
                    nngt._config["mpi"] = False
                    raise InvalidArgument("'msd' entry must be the same on "
                                          "all MPI processes.")
            elif any(seed is not None for seed in msd):
                # some processes have a master seed while others have none
                raise InvalidArgument("'msd' entry must be the same on "
                                      "all MPI processes.")

    # reload for mpi
    if new_config.get('mpi', old_mpi) != old_mpi:
        reload_module(sys.modules["nngt"].generation.graph_connectivity)

    # set graph-tool config
    _set_gt_config(old_gl, new_config)

    # seed python RNGs
    if old_msd != nngt._config['msd'] or not nngt._seeded:
        nngt_seed(msd=nngt._config['msd'])
# Template of the summary built by `set_config` (through `str.format`) and
# logged when the configuration changed. Expected fields: gl, thread, omp,
# s, mpi, plot, nest, shapely, svg, dxf, and db.
config_info = '''
# -------------- #
# Config changed #
# -------------- #
Graph library: {gl}
Multithreading: {thread} ({omp} thread{s})
MPI: {mpi}
Plotting: {plot}
NEST support: {nest}
Shapely: {shapely}
SVG support: {svg}
DXF support: {dxf}
Database: {db}
'''