# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
# SPDX-License-Identifier: GPL-3.0-or-later
# nngt/lib/test_functions.py

""" Test functions for the NNGT """

import collections
import functools
import inspect
import warnings

from collections.abc import Container as _container
from collections.abc import Iterable as _iterable
from collections.abc import KeysView as _key_view
from collections.abc import ValuesView as _value_view

import numpy as np

import nngt
from .decorator import decorate


def deprecated(version, reason=None, alternative=None, removal=None):
    '''
    Decorator to mark deprecated functions.

    Each call to the decorated function emits a `DeprecationWarning` before
    the function itself is executed.

    Parameters
    ----------
    version : object
        Version since which the function is deprecated (inserted in the
        warning message via ``str.format``).
    reason : str, optional (default: None)
        Explanation appended to the message after " because ".
    alternative : str, optional (default: None)
        Name of the replacement, appended as " Use <alternative> instead."
    removal : object, optional (default: None)
        Version in which the function will be removed.

    Returns
    -------
    decorator : callable
        Decorator to apply on the deprecated function.
    '''
    def decorator(func):
        def wrapper(func, *args, **kwargs):
            message = "Function {} is deprecated since version {}"
            message = message.format(func.__name__, version)

            if reason is not None:
                message += " because " + reason + "."
            else:
                message += "."

            if removal is not None:
                message += " It will be removed in version {}.".format(removal)

            if alternative is not None:
                message += " Use " + alternative + " instead."

            # Force the warning to be emitted on every call, but only for
            # the duration of this `warn`: `catch_warnings` restores the
            # filters that were active before, so user settings (e.g.
            # "error" or "ignore" for DeprecationWarning) are not clobbered.
            with warnings.catch_warnings():
                warnings.simplefilter('always', DeprecationWarning)
                warnings.warn(message, category=DeprecationWarning)

            return func(*args, **kwargs)

        return decorate(func, wrapper)

    return decorator
def on_master_process():
    '''
    Tell whether the calling code runs on the master process (rank 0).

    Returns
    -------
    True if the MPI rank is 0 or if mpi4py cannot be imported, otherwise
    False.
    '''
    try:
        from mpi4py import MPI

        return MPI.COMM_WORLD.Get_rank() == 0
    except ImportError:
        # mpi4py is not available: the current process counts as master
        return True
def num_mpi_processes():
    '''
    Return the size of the MPI world communicator, or 1 if mpi4py cannot be
    imported.
    '''
    try:
        from mpi4py import MPI

        world = MPI.COMM_WORLD
        num_procs = world.Get_size()
    except ImportError:
        num_procs = 1

    return num_procs
def _mpi_sync():
    ''' Call a barrier on the MPI world communicator if mpi4py is there. '''
    try:
        from mpi4py import MPI

        MPI.COMM_WORLD.Barrier()
    except ImportError:
        pass


def mpi_barrier(func=None):
    '''
    MPI barrier, usable either as a decorator or as a plain call.

    Parameters
    ----------
    func : callable, optional (default: None)
        Function to decorate; the barrier is then executed before each call
        to `func`. If None, the barrier is executed immediately and nothing
        is returned.
    '''
    def wrapper(func, *args, **kwargs):
        _mpi_sync()

        if func is None:
            return None

        return func(*args, **kwargs)

    if func is None:
        # plain call: just execute the barrier
        wrapper(None)
        return None

    # decorator usage
    return decorate(func, wrapper)


def mpi_checker(logging=False):
    '''
    Decorator restricting the execution of a function to the master process,
    unless the "nngt" backend is in use.

    Parameters
    ----------
    logging : bool, optional (default: False)
        If True, the backend is not looked up and the function only runs on
        the master process.

    Note
    ----
    The decorated function returns None on processes where it is not
    executed.
    '''
    def decorator(func):
        def wrapper(func, *args, **kwargs):
            # when using MPI, every process waits for the others first
            _mpi_sync()

            # the "nngt" backend is fully parallel, not the others
            parallel_backend = (
                (not logging) and nngt.get_config("backend") == "nngt")

            if parallel_backend or on_master_process():
                return func(*args, **kwargs)

            return None

        return decorate(func, wrapper)

    return decorator


def mpi_random(func):
    '''
    Decorator that gives every MPI process the NumPy random state of rank 0
    before calling the function (no-op if mpi4py cannot be imported).
    '''
    def wrapper(func, *args, **kwargs):
        try:
            from mpi4py import MPI

            world = MPI.COMM_WORLD

            # only rank 0 provides its state, the others receive it
            local_state = (
                np.random.get_state() if world.Get_rank() == 0 else None)

            np.random.set_state(world.bcast(local_state, root=0))
        except ImportError:
            pass

        return func(*args, **kwargs)

    return decorate(func, wrapper)
def nonstring_container(obj):
    '''
    Tell whether `obj` is a container other than a `str` or `bytes` object.

    Dictionary key and value views are always accepted.
    '''
    if isinstance(obj, (_key_view, _value_view)):
        return True

    is_text = isinstance(obj, (bytes, str))

    return isinstance(obj, _container) and not is_text
def is_integer(obj):
    ''' Tell whether `obj` is a Python `int` or a NumPy integer. '''
    integer_types = (int, np.integer)

    return isinstance(obj, integer_types)
def is_iterable(obj):
    ''' Tell whether `obj` is an instance of `collections.abc.Iterable`. '''
    iterable = isinstance(obj, _iterable)

    return iterable
def graph_tool_check(version_min):
    '''
    Raise an error for function not working with old versions of graph-tool.

    Parameters
    ----------
    version_min : str
        Minimal graph-tool version (e.g. "2.22") required by the decorated
        function.

    Raises
    ------
    NotImplementedError
        When the decorated function is called while the backend is
        graph-tool with a version older than `version_min`.
    '''
    def decorator(func):
        def wrapper(func, *args, **kwargs):
            if _old_graph_tool(version_min):
                raise NotImplementedError(
                    'This function is not working for graph-tool < '
                    + version_min + '.')

            return func(*args, **kwargs)

        # `decorate` is used to preserve the docstring info
        return decorate(func, wrapper)

    return decorator


def _version_tuple(version):
    '''
    Convert the leading "X.Y.Z" part of a version string into a tuple of
    integers, ignoring any suffix; an empty tuple is returned if the string
    does not start with a number.
    '''
    import re

    match = re.match(r"\s*(\d+(?:\.\d+)*)", str(version))

    if match is None:
        return ()

    return tuple(int(part) for part in match.group(1).split("."))


def _old_graph_tool(version_min):
    '''
    Check for old versions of graph-tool for which some functions are not
    working.

    Versions are compared numerically, component by component, so that
    e.g. "2.9" is correctly considered older than "2.22" (a plain string
    comparison orders them the other way around).

    Parameters
    ----------
    version_min : str
        Minimal graph-tool version required.

    Returns
    -------
    True if the backend is graph-tool and the version of the library is
    lower than `version_min`, otherwise False. A library version without a
    leading number gives an empty tuple and is therefore treated as old.
    '''
    # the library is only queried when graph-tool is the current backend
    if nngt.get_config('backend') != 'graph-tool':
        return False

    installed = nngt.get_config('library').__version__

    return _version_tuple(installed) < _version_tuple(version_min)