Source code for nngt.lib.nngt_config

#!/usr/bin/env python
#-*- coding:utf-8 -*-
#
# This file is part of the NNGT project to generate and analyze
# neuronal networks and their activity.
# Copyright (C) 2015-2017  Tanguy Fardet
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

""" Configuration tools for NNGT """

import os
import sys
import logging
from copy import deepcopy

import numpy as np

import nngt
from .errors import InvalidArgument
from .logger import _configure_logger, _init_logger, _log_message
from .reloading import reload_module
from .rng_tools import seed as nngt_seed
from .test_functions import mpi_checker, num_mpi_processes, mpi_barrier


logger = logging.getLogger(__name__)


# ----------------- #
# Getter and setter #
# ----------------- #

def get_config(key=None, detailed=False):
    '''
    Return the NNGT configuration, or a single entry of it.

    Parameters
    ----------
    key : str, optional (default: None)
        Name of one configuration entry. When given, the value stored under
        that name in `nngt._config` is returned directly.
    detailed : bool, optional (default: False)
        Only used when `key` is None. If True, a copy of the whole
        configuration is returned; otherwise the database entries (when the
        database is not used), the MPI communicator (when MPI is not used),
        the technical entries and the logging entries are removed first.

    Returns
    -------
    A dict (copy of the configuration, possibly filtered) when `key` is None,
    otherwise the value associated to `key`.

    Note
    ----
    This function has no MPI barrier on it.
    '''
    # single entry requested: hand it back as stored
    if key is not None:
        return nngt._config[key]

    cfg = dict(nngt._config.items())
    if detailed:
        return cfg

    # entries whose removal is postponed until all scans of `cfg` are done
    hidden = []

    # hide database config if not used
    if not nngt._config["use_database"]:
        hidden.extend(name for name in cfg if name.startswith('db_'))

    # hide mpi conf if not used
    if not nngt._config["mpi"]:
        del cfg['mpi_comm']

    # hide technical stuff
    technical = ("load_nest", "graph", "library", "palette", "use_tex",
                 "mpl_backend", "color_lib")
    for name in technical:
        del cfg[name]

    # hide log config
    hidden.extend(name for name in cfg if name.startswith('log_'))

    for name in hidden:
        del cfg[name]

    return cfg
@mpi_barrier
def set_config(config, value=None, silent=False):
    '''
    Set NNGT's configuration.

    Parameters
    ----------
    config : dict or str
        Either a full configuration dictionary or one key to be set together
        with its associated value.
    value : object, optional (default: None)
        Value associated to `config` if `config` is a key.
    silent : bool, optional (default: False)
        If True, the summary of the new configuration is not logged.

    Raises
    ------
    KeyError
        If one of the requested properties does not exist in the current
        configuration. All keys are checked before anything is modified.

    Examples
    --------

    >>> nngt.set_config({'multithreading': True, 'omp': 4})
    >>> nngt.set_config('multithreading', False)

    Notes
    -----
    See the config file `nngt/nngt.conf.default` or `~/.nngt/nngt.conf` for
    details about your configuration.

    This function has an MPI barrier on it, so it must always be called on all
    processes.

    See also
    --------
    :func:`~nngt.get_config`
    '''
    # snapshot of the values needed to detect what changed
    old_mt = nngt._config["multithreading"]
    old_mpi = nngt._config["mpi"]
    old_omp = nngt._config["omp"]
    old_gl = nngt._config["backend"]
    old_msd = nngt._config["msd"]
    old_config = nngt._config.copy()

    # always work on a private dict so the caller's argument is not modified
    if isinstance(config, dict):
        new_config = config.copy()
    else:
        new_config = {config: value}

    # validate every key first, so that an unknown property cannot leave the
    # configuration half-applied (e.g. backend already switched)
    for key in new_config:
        if key not in nngt._config:
            raise KeyError("Unknown configuration property: {}".format(key))

    # normalise the values that need it
    if "log_level" in new_config:
        new_config["log_level"] = _convert(new_config["log_level"])
    if "backend" in new_config and new_config["backend"] != old_gl:
        nngt.use_backend(new_config["backend"])
    for folder_key in ("log_folder", "db_folder"):
        if folder_key in new_config:
            new_config[folder_key] = os.path.abspath(
                os.path.expanduser(new_config[folder_key]))

    # check multithreading status and number of threads
    _pre_update_parallelism(new_config, old_mt, old_omp, old_mpi)

    # update
    nngt._config.update(new_config)

    # apply multithreading parameters
    _post_update_parallelism(new_config, old_gl, old_msd, old_mt, old_mpi)

    # update matplotlib
    if nngt._config['use_tex']:
        import matplotlib
        matplotlib.rc('text', usetex=True)

    # update database
    if nngt._config["use_database"] and not hasattr(nngt, "db"):
        from .. import database
        sys.modules["nngt.database"] = database
    if nngt._config["db_to_file"]:
        _log_message(logger, "WARNING",
                     "This functionality is not available")

    # log changes
    _configure_logger(nngt._logger)

    # the graph library provides the version string; fall back on nngt itself
    # when no library is loaded
    glib = nngt._config["library"]
    if glib is None:
        glib = nngt

    num_mpi = num_mpi_processes()
    s_mpi = False if not nngt._config["mpi"] else "True ({} process{})".format(
        num_mpi, "es" if num_mpi > 1 else "")

    # optional dependencies: best-effort detection, any failure while
    # importing simply means "not available" (but Ctrl-C must still work,
    # hence no bare `except`)
    try:
        import svg.path
        has_svg = True
    except Exception:
        has_svg = False

    try:
        import dxfgrabber
        has_dxf = True
    except Exception:
        has_dxf = False

    try:
        import shapely
        has_shapely = shapely.__version__
    except Exception:
        has_shapely = False

    conf_info = config_info.format(
        gl=nngt._config["backend"] + " " + glib.__version__[:5],
        thread=nngt._config["multithreading"],
        plot=nngt._config["with_plot"],
        nest=nngt._config["with_nest"],
        db=nngt._config["use_database"],
        omp=nngt._config["omp"],
        s="s" if nngt._config["omp"] > 1 else "",
        mpi=s_mpi,
        shapely=has_shapely,
        svg=has_svg,
        dxf=has_dxf,
    )

    # only report when something actually changed
    if not silent and old_config != nngt._config:
        _log_conf_changed(conf_info)
# ----- #
# Tools #
# ----- #

def _convert(value):
    '''
    Convert a raw configuration value into a Python object.

    The value is first turned into a string, then:

    * a string made only of digits becomes an int,
    * "true"/"false" (any case) become the corresponding bool,
    * the name of a logging level (any case) becomes the `logging` constant,
    * anything else is returned as the string itself.
    '''
    value = str(value)

    if value.isdigit():
        return int(value)

    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False

    log_levels = {
        "CRITICAL": logging.CRITICAL,
        "DEBUG": logging.DEBUG,
        "ERROR": logging.ERROR,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
    }
    return log_levels.get(value.upper(), value)


def _load_config(path_config):
    '''
    Parse the configuration file at `path_config`, store every
    ``name = value`` entry into `nngt._config` (values go through
    :func:`_convert`), then initialise the NNGT logger.

    Blank lines, comment lines (first non-blank character is ``#``) and lines
    without an ``=`` sign are ignored. Nothing is returned.
    '''
    with open(path_config, 'r') as fconfig:
        for raw_line in fconfig:
            line = raw_line.strip()
            # test the stripped line so that indented comments are skipped too
            if not line or line.startswith("#"):
                continue
            opt_name, sep, opt_value = line.partition("=")
            if not sep:
                # not a `name = value` entry: there is no key to store
                continue
            nngt._config[opt_name.strip()] = _convert(opt_value.strip())
    _init_logger(nngt._logger)


@mpi_checker(logging=True)
def _log_conf_changed(conf_info):
    ''' Send the configuration summary `conf_info` to the module logger. '''
    logger.info(conf_info)


def _set_gt_config(old_gl, new_config):
    '''
    Propagate a new `omp` value to graph-tool.

    Nothing is done unless "omp" is in `new_config`, the backend was and
    remains "graph-tool", and a library is loaded. When NEST support is
    enabled, the number of threads is only forwarded if it matches NEST's
    "local_num_threads"; otherwise a warning is logged.
    '''
    using_gt = (old_gl == "graph-tool"
                and new_config.get("backend", old_gl) == "graph-tool"
                and nngt._config["library"] is not None)

    if "omp" in new_config and using_gt:
        omp_nest = new_config["omp"]
        if nngt._config['with_nest']:
            import nest
            omp_nest = nest.GetKernelStatus("local_num_threads")
        if omp_nest == new_config["omp"]:
            nngt._config["library"].openmp_set_num_threads(nngt._config["omp"])
        else:
            _log_message(logger, "WARNING",
                         "Using NEST and graph_tool, OpenMP number must be "
                         "consistent throughout the code. Current NEST "
                         "config states omp = " + str(omp_nest) + ", hence "
                         "`graph_tool` configuration was not changed.")


def _pre_update_parallelism(new_config, old_mt, old_omp, old_mpi):
    '''
    Make the parallelism-related entries of `new_config` consistent before
    they are written into `nngt._config`.

    `new_config` is modified in place: "multithreading" and "mpi" may be
    added or switched, and "seeds"/"msd" are reset to None when no seeds are
    provided and the parallel setup changes.

    Raises
    ------
    InvalidArgument
        If both "mpi" and "multithreading" are requested.
    AssertionError
        If the provided seeds do not match the number of MPI processes or
        OpenMP threads, or are not all different.
    '''
    mt = "multithreading"

    if "omp" in new_config:
        if new_config["omp"] > 1:
            if mt in new_config and not new_config[mt]:
                _log_message(logger, "WARNING",
                             "Updating to 'multithreading' == False with "
                             "'omp' greater than one.")
            elif mt not in new_config and not old_mt:
                new_config[mt] = True
                _log_message(logger, "WARNING",
                             "'multithreading' was set to False but new "
                             "'omp' is greater than one. Updating "
                             "'multithreading' to True.")

    # MPI and OpenMP are mutually exclusive
    if new_config.get('mpi', False) and new_config.get(mt, False):
        raise InvalidArgument('Cannot set both "mpi" and "multithreading" to '
                              'True simultaneously, choose one or the other.')
    elif new_config.get(mt, False):
        new_config['mpi'] = False
    elif new_config.get('mpi', False):
        if old_mt:
            new_config[mt] = False
            _log_message(logger, "WARNING",
                         '"mpi" set to True but previous configuration was '
                         'using OpenMP; setting "multithreading" to False '
                         'to switch to mpi algorithms.')

    with_mt = new_config.get(mt, old_mt)
    with_mpi = new_config.get('mpi', old_mpi)

    # check that seeds are correct
    if new_config.get('seeds', None) is not None:
        seeds = new_config['seeds']
        err = 'Expected {} seeds.'
        err2 = 'All seeds must be different.'
        if with_mpi:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            size = comm.Get_size()
            assert size == len(seeds), err.format(size)
            assert len(set(seeds)) == len(seeds), err2
        elif with_mt:
            num_omp = new_config.get("omp", old_omp)
            assert num_omp == len(seeds), err.format(num_omp)
            assert len(set(seeds)) == len(seeds), err2
    else:
        # reset seeds if necessary
        current_omp = nngt._config["omp"]
        reset_seeds = (
            # - because the number of threads changed (an absent "omp" entry
            #   means "unchanged", so it must default to the current value)
            new_config.get("omp", current_omp) != current_omp
            # - because we switched from OpenMP to MPI
            or (with_mpi and old_mt)
            # - because we switched from MPI to OpenMP
            or (with_mt and old_mpi)
        )
        if reset_seeds:
            new_config['seeds'] = None
            new_config['msd'] = None
            nngt._seeded = False


def _post_update_parallelism(new_config, old_gl, old_msd, old_mt, old_mpi):
    '''
    Apply the parallelism settings once `nngt._config` has been updated:
    reload the connectivity module when the multithreading or MPI status
    changed, register the MPI communicator and check the master seed,
    forward the configuration to graph-tool and reseed the RNGs if needed.

    Raises
    ------
    InvalidArgument
        On the MPI root process, if the "msd" entry is not the same on all
        processes.
    '''
    # reload for omp
    new_multithreading = new_config.get("multithreading", old_mt)
    if new_multithreading != old_mt:
        reload_module(sys.modules["nngt"].generation.graph_connectivity)

    # if multithreading loading failed, set omp back to 1
    if not nngt._config['multithreading']:
        nngt._config['omp'] = 1
        nngt._config['seeds'] = None

    # if MPI is on, set mpi_comm and check random numbers
    if new_config.get('mpi', old_mpi):
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        nngt._config['mpi_comm'] = comm
        # check that master seed is the same everywhere
        msd = comm.gather(nngt._config['msd'], root=0)
        if rank == 0:
            if None not in msd:
                msd = np.array(msd, dtype=int)
                # `np.alltrue` no longer exists in NumPy >= 2.0
                if not np.all(msd == msd[0]):
                    nngt._config["mpi"] = False
                    raise InvalidArgument("'msd' entry must be the same on "
                                          "all MPI processes.")
            elif any(seed is not None for seed in msd):
                # some processes have a master seed while others have none
                raise InvalidArgument("'msd' entry must be the same on "
                                      "all MPI processes.")

    # reload for mpi
    if new_config.get('mpi', old_mpi) != old_mpi:
        reload_module(sys.modules["nngt"].generation.graph_connectivity)

    # set graph-tool config
    _set_gt_config(old_gl, new_config)

    # seed python RNGs
    if old_msd != nngt._config['msd'] or not nngt._seeded:
        nngt_seed(msd=nngt._config['msd'])


# template of the message logged by `set_config` when the configuration changed
config_info = '''
# -------------- #
# Config changed #
# -------------- #
Graph library:  {gl}
Multithreading: {thread} ({omp} thread{s})
MPI:            {mpi}
Plotting:       {plot}
NEST support:   {nest}
Shapely:        {shapely}
SVG support:    {svg}
DXF support:    {dxf}
Database:       {db}
'''