# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
# SPDX-License-Identifier: GPL-3.0-or-later
# nngt/lib/test_functions.py
""" Test functions for the NNGT """
import collections
import functools
import inspect
import warnings
from collections.abc import Container as _container
from collections.abc import Iterable as _iterable
from collections.abc import KeysView as _key_view
from collections.abc import ValuesView as _value_view
import numpy as np
import nngt
from .decorator import decorate
def deprecated(version, reason=None, alternative=None, removal=None):
'''
Decorator to mark deprecated functions.
'''
def decorator(func):
def wrapper(func, *args, **kwargs):
# turn off filter temporarily
warnings.simplefilter('always', DeprecationWarning)
message = "Function {} is deprecated since version {}"
message = message.format(func.__name__, version)
if reason is not None:
message += " because " + reason + "."
else:
message += "."
if removal is not None:
message += " It will be removed in version {}.".format(removal)
if alternative is not None:
message += " Use " + alternative + " instead."
warnings.warn(message, category=DeprecationWarning)
warnings.simplefilter('default', DeprecationWarning)
return func(*args, **kwargs)
return decorate(func, wrapper)
return decorator
# ~ def constrain_argument(arg, valid_values):
# ~ '''
# ~ Decorator to constrain function arguments.
# ~ '''
# ~ def decorator(func):
# ~ @functools.wrap(func)
# ~ def wrapper(func, *args, **kwargs):
# ~ arg_names = inspect.getargspec(func)[0]
# ~ arguments = {k: v for k, v in zip(arg_names, args)}
# ~ arguments.update(kwargs)
# ~ if arg in arguments:
# ~ assert arguments[arg] in valid_values, \
# ~ "Argument `'{}' must be among {}".format(arg, valid_values)
# ~ return func(*args, **kwargs)
# ~ return decorate(func, wrapper)
# ~ return decorator
# ~ def check_arguments(arg_list):
# ~ '''
# ~ Decorator to constrain function arguments.
# ~ '''
# ~ def decorator(func):
# ~ @functools.wrap(func)
# ~ def wrapper(func, *args, **kwargs):
# ~ arg_names = set(inspect.getargspec(func)[0])
# ~ allowed = arg_names.union(arg_list)
# ~ for arg in kwargs:
# ~ if arg not in allowed:
# ~ raise InvalidArgument(("'{}' is not a valid argument for "
# ~ "function '{}'.").format(arg, func))
# ~ return func(*args, **kwargs)
# ~ return decorate(func, wrapper)
# ~ return decorator
[docs]def on_master_process():
'''
Check whether the current code is executing on the master process (rank 0)
if MPI is used.
Returns
-------
True if rank is 0, if mpi4py is not present or if MPI is not used,
otherwise False.
'''
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
if rank == 0:
return True
else:
return False
except ImportError:
return True
[docs]def num_mpi_processes():
''' Returns the number of MPI processes (1 if MPI is not used) '''
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
return comm.Get_size()
except ImportError:
return 1
def mpi_barrier(func=None):
def wrapper(func, *args, **kwargs):
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
comm.Barrier()
except ImportError:
pass
if func is not None:
return func(*args, **kwargs)
# act as a real decorator
if func is not None:
return decorate(func, wrapper)
# otherwise just execute the barrier
wrapper(None)
def mpi_checker(logging=False):
'''
Decorator used to check for mpi and make sure only rank zero is used
to store and generate the graph if the mpi algorithms are activated.
'''
def decorator(func):
def wrapper(func, *args, **kwargs):
# when using MPI, make sure everyone waits for the others
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
comm.Barrier()
except ImportError:
pass
# check backend ("nngt" is fully parallel, not the others)
backend = False
if not logging:
backend = nngt.get_config("backend") == "nngt"
if backend or on_master_process():
return func(*args, **kwargs)
else:
return None
return decorate(func, wrapper)
return decorator
def mpi_random(func):
'''
Decorator asserting that all processes start with same random seed when
using mpi.
'''
def wrapper(func, *args, **kwargs):
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
if rank == 0:
state = np.random.get_state()
else:
state = None
state = comm.bcast(state, root=0)
np.random.set_state(state)
except ImportError:
pass
return func(*args, **kwargs)
return decorate(func, wrapper)
[docs]def nonstring_container(obj):
'''
Returns true for any iterable which is not a string or byte sequence.
'''
if isinstance(obj, (_key_view, _value_view)):
return True
if not isinstance(obj, _container):
return False
if isinstance(obj, (bytes, str)):
return False
return True
[docs]def is_integer(obj):
''' Return whether the object is an integer '''
return isinstance(obj, (int, np.integer))
[docs]def is_iterable(obj):
''' Return whether the object is iterable '''
return isinstance(obj, _iterable)
def graph_tool_check(version_min):
'''
Raise an error for function not working with old versions of graph-tool.
'''
def decorator(func):
def wrapper(func, *args, **kwargs):
old_graph_tool = _old_graph_tool(version_min)
if old_graph_tool:
raise NotImplementedError('This function is not working for '
'graph-tool < ' + version_min + '.')
else:
return func(*args, **kwargs)
return decorate(func, wrapper) # to preserve the docstring info
return decorator
def _old_graph_tool(version_min):
'''
Check for old versions of graph-tool for which some functions are not
working.
'''
return (nngt.get_config('backend') == 'graph-tool'
and nngt.get_config('library').__version__[:4] < version_min)