numba_mpi.common
Variables used across the API implementation.
"""variables used across API implementation"""

import ctypes
import os
import sys
from ctypes.util import find_library
from pathlib import Path

import numba
import numpy as np
import psutil
from mpi4py import MPI

# pylint: disable=protected-access
_MPI_Comm_World_ptr = MPI._addressof(MPI.COMM_WORLD)

# Maps numpy dtypes to the addresses of the corresponding MPI datatype handles.
# NOTE: np.dtype("float") is float64 (identical to np.dtype("double")), so the
# single-precision entry must be keyed by "float32"; otherwise it is silently
# overwritten by the "double" entry and float32 arrays have no MPI datatype.
_MPI_DTYPES = {
    np.dtype("uint8"): MPI._addressof(MPI.CHAR),
    np.dtype("int32"): MPI._addressof(MPI.INT32_T),
    np.dtype("int64"): MPI._addressof(MPI.INT64_T),
    np.dtype("float32"): MPI._addressof(MPI.FLOAT),
    np.dtype("double"): MPI._addressof(MPI.DOUBLE),
    np.dtype("complex64"): MPI._addressof(MPI.C_FLOAT_COMPLEX),
    np.dtype("complex128"): MPI._addressof(MPI.C_DOUBLE_COMPLEX),
}

# MPI implementations represent handles either as C ints or as pointers; pick
# the ctypes type matching the handle size reported by mpi4py.
if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int):
    _MpiComm = ctypes.c_int
else:
    _MpiComm = ctypes.c_void_p

if MPI._sizeof(MPI.Datatype) == ctypes.sizeof(ctypes.c_int):
    _MpiDatatype = ctypes.c_int
    _MpiOp = ctypes.c_int
else:
    _MpiDatatype = ctypes.c_void_p
    _MpiOp = ctypes.c_void_p

if MPI._sizeof(MPI.Request) == ctypes.sizeof(ctypes.c_int):
    RequestType = np.intc
else:
    RequestType = np.uintp

# Number of C ints needed to hold one MPI_Status. The struct layout differs
# between MPI implementations, so derive it from the size mpi4py reports
# (rounded up to whole ints) instead of hard-coding it; never go below the
# previously used 5 ints.
_MPI_STATUS_INTS = max(
    5, -(-MPI._sizeof(MPI.Status) // ctypes.sizeof(ctypes.c_int))
)

# pylint: enable=protected-access
_MpiStatusPtr = ctypes.c_void_p
_MpiRequestPtr = ctypes.c_void_p


# TODO: add proper handling of status objects
@numba.njit
def create_status_buffer(count=1):
    """Helper function for creating a numpy array used as storage for
    ``count`` MPI_Status results (``_MPI_STATUS_INTS`` C ints per status)."""
    return np.empty(count * _MPI_STATUS_INTS, dtype=np.intc)


LIB = None
names = ("mpich", "mpi", "msmpi", "impi")

ps = psutil.Process(os.getpid())
windows = os.name == "nt"
if hasattr(ps, "memory_maps"):
    # Prefer an MPI library already mapped into this process (mpi4py is
    # imported above, so its MPI library is expected to be among them).
    for dll in ps.memory_maps():
        path = Path(dll.path)
        if windows:
            candidate = path.stem
        elif path.name.startswith("lib"):
            # Match against the full file name: Path.stem drops only the last
            # suffix, so for an unversioned "libmpi.so" the stem "libmpi"
            # would never contain "mpi.".
            candidate = path.name
        else:
            continue
        for name in names:
            if name + ("" if windows else ".") in candidate:
                LIB = str(path)
                break
else:
    # memory_maps() is not available on every platform; fall back to the
    # standard library-search mechanism.
    for name in names:
        LIB = find_library(name)
        if LIB is not None:
            break

if LIB is None:
    if sys.platform == "darwin":
        raise RuntimeError(
            """MPI library not found, if MPI was installed with Homebrew, export the following:
    ARM: DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib:/usr/lib:$DYLD_FALLBACK_LIBRARY_PATH"
    Intel: DYLD_FALLBACK_LIBRARY_PATH="/usr/local/lib:/usr/lib:$DYLD_FALLBACK_LIBRARY_PATH"
    """
        )
    raise RuntimeError("no MPI library found")

libmpi = ctypes.CDLL(LIB)

# Argument types shared by the blocking point-to-point wrappers:
# (buffer, count, datatype, rank, tag, communicator).
send_recv_args = [
    ctypes.c_void_p,
    ctypes.c_int,
    _MpiDatatype,
    ctypes.c_int,
    ctypes.c_int,
    _MpiComm,
]

# Non-blocking variants additionally take a pointer to the output request.
send_recv_async_args = send_recv_args + [_MpiRequestPtr]
@numba.njit
def create_status_buffer(count=1):
    """Allocate an uninitialised numpy buffer that receives MPI_Status results.

    Each status is given 5 C ints (``np.intc``), so the returned 1-D array
    holds ``5 * count`` elements.
    """
    ints_per_status = 5
    return np.empty(ints_per_status * count, dtype=np.intc)
Helper function for creating a numpy array that serves as storage for MPI_Status results.
LIB =
'/usr/lib/x86_64-linux-gnu/libmpich.so.12.4.1'
names =
('mpich', 'mpi', 'msmpi', 'impi')
ps =
psutil.Process(pid=3236, name='python', status='running')
windows =
False
libmpi =
<CDLL '/usr/lib/x86_64-linux-gnu/libmpich.so.12.4.1', handle 55ece2ba2630>
send_recv_args =
[<class 'ctypes.c_void_p'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>]
send_recv_async_args =
[<class 'ctypes.c_void_p'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_void_p'>]