numba_mpi.common

Variables used across the API implementation.

 1"""variables used across API implementation"""
 2
 3import ctypes
 4import os
 5import sys
 6from ctypes.util import find_library
 7from pathlib import Path
 8
 9import numba
10import numpy as np
11import psutil
12from mpi4py import MPI
13
# pylint: disable=protected-access
_MPI_Comm_World_ptr = MPI._addressof(MPI.COMM_WORLD)

# Maps numpy dtypes to the addresses of the matching MPI datatype handles.
# Sized dtype names are used on purpose: np.dtype("float") is float64 (the
# same key as np.dtype("double")), which previously made the MPI.FLOAT entry
# get overwritten by MPI.DOUBLE and left float32 without any mapping.
# NOTE(review): uint8 maps to MPI.CHAR; MPI.UINT8_T / MPI.UNSIGNED_CHAR may be
# more appropriate for reductions -- confirm before changing.
_MPI_DTYPES = {
    np.dtype("uint8"): MPI._addressof(MPI.CHAR),
    np.dtype("int32"): MPI._addressof(MPI.INT32_T),
    np.dtype("int64"): MPI._addressof(MPI.INT64_T),
    np.dtype("float32"): MPI._addressof(MPI.FLOAT),
    np.dtype("float64"): MPI._addressof(MPI.DOUBLE),
    np.dtype("complex64"): MPI._addressof(MPI.C_FLOAT_COMPLEX),
    np.dtype("complex128"): MPI._addressof(MPI.C_DOUBLE_COMPLEX),
}
26
# MPI handle types are implementation-defined: int-sized handles (as seen with
# MPICH) are passed as C ints, anything else is treated as a pointer.
_MpiComm = (
    ctypes.c_int
    if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int)
    else ctypes.c_void_p
)

_MpiDatatype = (
    ctypes.c_int
    if MPI._sizeof(MPI.Datatype) == ctypes.sizeof(ctypes.c_int)
    else ctypes.c_void_p
)
# MPI_Op is given the same representation as MPI_Datatype.
_MpiOp = _MpiDatatype

# numpy dtype used to store MPI_Request handles in arrays.
RequestType = (
    np.intc if MPI._sizeof(MPI.Request) == ctypes.sizeof(ctypes.c_int) else np.uintp
)

# pylint: enable=protected-access
# Status and request arguments are always passed by address.
_MpiStatusPtr = ctypes.c_void_p
_MpiRequestPtr = ctypes.c_void_p
47
48
# pylint: disable=protected-access
# Number of C ints needed to store one MPI_Status. The struct layout is
# implementation-specific: 5 ints (20 bytes) fit MPICH, but e.g. Open MPI's
# MPI_Status is 24 bytes on 64-bit platforms, so a hard-coded 5 would let MPI
# write past the end of the buffer. Ceiling division rounds up to whole ints.
_MPI_STATUS_INTC_COUNT = -(-MPI._sizeof(MPI.Status) // ctypes.sizeof(ctypes.c_int))
# pylint: enable=protected-access


# TODO: add proper handling of status objects
@numba.njit
def create_status_buffer(count=1):
    """Helper function for creating a numpy array used as storage for MPI_Status results.

    The array is uninitialized and made of ``np.intc`` elements, sized so that
    ``count`` consecutive MPI_Status structs fit into it (presumably its data
    pointer is handed to MPI calls that fill in statuses).

    Parameters
    ----------
    count : int
        Number of MPI_Status structs to make room for (default 1).

    Returns
    -------
    numpy.ndarray
        Array of ``count * _MPI_STATUS_INTC_COUNT`` C ints.
    """
    return np.empty(count * _MPI_STATUS_INTC_COUNT, dtype=np.intc)
54
55
LIB = None
names = ("mpich", "mpi", "msmpi", "impi")

ps = psutil.Process(os.getpid())
windows = os.name == "nt"


def _looks_like_mpi_library(path):
    """Tell whether ``path`` names one of the MPI libraries listed in ``names``.

    On Windows every mapped file is considered and a plain substring match on
    the file stem is used. Elsewhere only ``lib*`` files are considered and the
    name must be followed by a dot (``libmpich.so.12`` has the stem
    ``libmpich.so``), which keeps e.g. ``libmpifort.so.12`` from matching.
    """
    if not (windows or path.stem.startswith("lib")):
        return False
    separator = "" if windows else "."
    return any(name + separator in path.stem for name in names)


if hasattr(ps, "memory_maps"):
    # Scan the files mapped into this process (presumably the MPI library was
    # loaded by the mpi4py import) and stop at the first match instead of
    # letting later matches overwrite it.
    for dll in ps.memory_maps():
        path = Path(dll.path)
        if _looks_like_mpi_library(path):
            LIB = str(path)
            break

if LIB is None:
    # memory_maps() is not available on every platform and the scan above may
    # come up empty; in both cases fall back to the dynamic-linker search.
    for name in names:
        LIB = find_library(name)
        if LIB is not None:
            break

if LIB is None:
    if sys.platform == "darwin":
        raise RuntimeError(
            """MPI library not found, if MPI was installed with Homebrew, export the following:
            ARM: DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib:/usr/lib:$DYLD_FALLBACK_LIBRARY_PATH"
            Intel: DYLD_FALLBACK_LIBRARY_PATH="/usr/local/lib:/usr/lib:$DYLD_FALLBACK_LIBRARY_PATH"
            """
        )
    raise RuntimeError("no MPI library found")
84
libmpi = ctypes.CDLL(LIB)

# ctypes argument types for point-to-point calls; presumably the leading
# (buf, count, datatype, rank, tag, comm) arguments of MPI_Send/MPI_Recv.
send_recv_args = [
    ctypes.c_void_p,  # buffer address
    ctypes.c_int,  # element count
    _MpiDatatype,  # datatype handle
    ctypes.c_int,  # peer rank
    ctypes.c_int,  # tag
    _MpiComm,  # communicator handle
]

# non-blocking variants take one extra trailing request-pointer argument
send_recv_async_args = [*send_recv_args, _MpiRequestPtr]
@numba.njit
def create_status_buffer(count=1):
# pylint: disable=protected-access
# Number of C ints needed to store one MPI_Status. The struct layout is
# implementation-specific: 5 ints (20 bytes) fit MPICH, but e.g. Open MPI's
# MPI_Status is 24 bytes on 64-bit platforms, so a hard-coded 5 would let MPI
# write past the end of the buffer. Ceiling division rounds up to whole ints.
_MPI_STATUS_INTC_COUNT = -(-MPI._sizeof(MPI.Status) // ctypes.sizeof(ctypes.c_int))
# pylint: enable=protected-access


@numba.njit
def create_status_buffer(count=1):
    """Helper function for creating a numpy array used as storage for MPI_Status results.

    The array is uninitialized and made of ``np.intc`` elements, sized so that
    ``count`` consecutive MPI_Status structs fit into it (presumably its data
    pointer is handed to MPI calls that fill in statuses).

    Parameters
    ----------
    count : int
        Number of MPI_Status structs to make room for (default 1).

    Returns
    -------
    numpy.ndarray
        Array of ``count * _MPI_STATUS_INTC_COUNT`` C ints.
    """
    return np.empty(count * _MPI_STATUS_INTC_COUNT, dtype=np.intc)

Helper function for creating a numpy array used as storage for `count` MPI_Status results (not pointers to them).

LIB = '/usr/lib/x86_64-linux-gnu/libmpich.so.12.4.1'
names = ('mpich', 'mpi', 'msmpi', 'impi')
ps = psutil.Process(pid=3236, name='python', status='running')
windows = False
libmpi = <CDLL '/usr/lib/x86_64-linux-gnu/libmpich.so.12.4.1', handle 55ece2ba2630>
send_recv_args = [<class 'ctypes.c_void_p'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>]
send_recv_async_args = [<class 'ctypes.c_void_p'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_int'>, <class 'ctypes.c_void_p'>]